Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8a044754
Unverified
Commit
8a044754
authored
Aug 26, 2025
by
Chaojun Zhang
Committed by
GitHub
Aug 25, 2025
Browse files
[XPU] Delay BF16 check to worker init for spawn compatibility (#22979)
Signed-off-by:
chzhang
<
chaojun.zhang@intel.com
>
parent
9188ae7c
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
60 additions
and
47 deletions
+60
-47
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+20
-0
vllm/platforms/interface.py
vllm/platforms/interface.py
+7
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+20
-0
vllm/platforms/xpu.py
vllm/platforms/xpu.py
+11
-26
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+1
-21
vllm/v1/worker/xpu_worker.py
vllm/v1/worker/xpu_worker.py
+1
-0
No files found.
vllm/platforms/cuda.py
View file @
8a044754
...
@@ -518,6 +518,26 @@ class CudaPlatformBase(Platform):
...
@@ -518,6 +518,26 @@ class CudaPlatformBase(Platform):
supported
=
True
supported
=
True
return
supported
return
supported
@
classmethod
def
check_if_supports_dtype
(
cls
,
torch_dtype
:
torch
.
dtype
):
if
torch_dtype
==
torch
.
bfloat16
:
# noqa: SIM102
if
not
cls
.
has_device_capability
(
80
):
capability
=
cls
.
get_device_capability
()
gpu_name
=
cls
.
get_device_name
()
if
capability
is
None
:
compute_str
=
"does not have a compute capability"
else
:
version_str
=
capability
.
as_version_str
()
compute_str
=
f
"has compute capability
{
version_str
}
"
raise
ValueError
(
"Bfloat16 is only supported on GPUs "
"with compute capability of at least 8.0. "
f
"Your
{
gpu_name
}
GPU
{
compute_str
}
. "
"You can use float16 instead by explicitly setting the "
"`dtype` flag in CLI, for example: --dtype=half."
)
# NVML utils
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
...
...
vllm/platforms/interface.py
View file @
8a044754
...
@@ -572,6 +572,13 @@ class Platform:
...
@@ -572,6 +572,13 @@ class Platform:
"""
"""
return
False
return
False
@
classmethod
def
check_if_supports_dtype
(
cls
,
torch_dtype
:
torch
.
dtype
):
"""
Check if the dtype is supported by the current platform.
"""
raise
NotImplementedError
class
UnspecifiedPlatform
(
Platform
):
class
UnspecifiedPlatform
(
Platform
):
_enum
=
PlatformEnum
.
UNSPECIFIED
_enum
=
PlatformEnum
.
UNSPECIFIED
...
...
vllm/platforms/rocm.py
View file @
8a044754
...
@@ -462,3 +462,23 @@ class RocmPlatform(Platform):
...
@@ -462,3 +462,23 @@ class RocmPlatform(Platform):
def
is_kv_cache_dtype_supported
(
cls
,
kv_cache_dtype
:
str
,
def
is_kv_cache_dtype_supported
(
cls
,
kv_cache_dtype
:
str
,
model_config
:
"ModelConfig"
)
->
bool
:
model_config
:
"ModelConfig"
)
->
bool
:
return
True
return
True
@
classmethod
def
check_if_supports_dtype
(
cls
,
torch_dtype
:
torch
.
dtype
):
if
torch_dtype
==
torch
.
bfloat16
:
# noqa: SIM102
if
not
cls
.
has_device_capability
(
80
):
capability
=
cls
.
get_device_capability
()
gpu_name
=
cls
.
get_device_name
()
if
capability
is
None
:
compute_str
=
"does not have a compute capability"
else
:
version_str
=
capability
.
as_version_str
()
compute_str
=
f
"has compute capability
{
version_str
}
"
raise
ValueError
(
"Bfloat16 is only supported on GPUs "
"with compute capability of at least 8.0. "
f
"Your
{
gpu_name
}
GPU
{
compute_str
}
. "
"You can use float16 instead by explicitly setting the "
"`dtype` flag in CLI, for example: --dtype=half."
)
vllm/platforms/xpu.py
View file @
8a044754
...
@@ -97,13 +97,6 @@ class XPUPlatform(Platform):
...
@@ -97,13 +97,6 @@ class XPUPlatform(Platform):
from
vllm.config
import
CompilationLevel
from
vllm.config
import
CompilationLevel
vllm_config
.
compilation_config
.
level
=
CompilationLevel
.
NO_COMPILATION
# noqa: E501
vllm_config
.
compilation_config
.
level
=
CompilationLevel
.
NO_COMPILATION
# noqa: E501
# Instances created using VllmConfig() typically have model_config as
# None by default. The modification involves adding a check to prevent
# potential null exceptions check and update model config.
if
model_config
is
not
None
and
model_config
.
dtype
==
torch
.
bfloat16
\
and
not
cls
.
device_support_bf16
():
model_config
.
dtype
=
torch
.
float16
# lazy import to avoid circular import
# lazy import to avoid circular import
from
vllm.config
import
CUDAGraphMode
from
vllm.config
import
CUDAGraphMode
compilation_config
=
vllm_config
.
compilation_config
compilation_config
=
vllm_config
.
compilation_config
...
@@ -162,30 +155,11 @@ class XPUPlatform(Platform):
...
@@ -162,30 +155,11 @@ class XPUPlatform(Platform):
torch
.
xpu
.
reset_peak_memory_stats
(
device
)
torch
.
xpu
.
reset_peak_memory_stats
(
device
)
return
torch
.
xpu
.
max_memory_allocated
(
device
)
return
torch
.
xpu
.
max_memory_allocated
(
device
)
@
classmethod
def
device_support_bf16
(
cls
)
->
bool
:
device_name
=
cls
.
get_device_name
().
lower
()
if
cls
.
is_client_gpu_a770
():
logger
.
warning
(
"Intel Arc A770 have bfloat16 accuracy known issue,"
" fallback to float16"
)
return
False
else
:
logger
.
info
(
"Device name %s supports bfloat16. Please file an issue "
"if you encounter any accuracy problems with bfloat16."
,
device_name
)
return
True
@
classmethod
@
classmethod
def
is_data_center_gpu
(
cls
)
->
bool
:
def
is_data_center_gpu
(
cls
)
->
bool
:
device_name
=
cls
.
get_device_name
().
lower
()
device_name
=
cls
.
get_device_name
().
lower
()
return
device_name
.
count
(
"data center gpu"
)
>
0
return
device_name
.
count
(
"data center gpu"
)
>
0
@
classmethod
def
is_client_gpu_a770
(
cls
)
->
bool
:
device_name
=
cls
.
get_device_name
().
lower
()
return
device_name
.
count
(
"a770"
)
>
0
@
classmethod
@
classmethod
def
get_device_communicator_cls
(
cls
)
->
str
:
def
get_device_communicator_cls
(
cls
)
->
str
:
return
"vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"
# noqa
return
"vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"
# noqa
...
@@ -197,3 +171,14 @@ class XPUPlatform(Platform):
...
@@ -197,3 +171,14 @@ class XPUPlatform(Platform):
@
classmethod
@
classmethod
def
device_count
(
cls
)
->
int
:
def
device_count
(
cls
)
->
int
:
return
torch
.
xpu
.
device_count
()
return
torch
.
xpu
.
device_count
()
@
classmethod
def
check_if_supports_dtype
(
cls
,
torch_dtype
:
torch
.
dtype
):
if
torch_dtype
==
torch
.
bfloat16
:
# noqa: SIM102
device_name
=
cls
.
get_device_name
().
lower
()
# client gpu a770
if
device_name
.
count
(
"a770"
)
>
0
:
raise
ValueError
(
"Intel Arc A770 have bfloat16 accuracy known issue. "
"You can use float16 instead by explicitly setting the "
"`dtype` flag in CLI, for example: --dtype=half."
)
vllm/v1/worker/gpu_worker.py
View file @
8a044754
...
@@ -167,7 +167,7 @@ class Worker(WorkerBase):
...
@@ -167,7 +167,7 @@ class Worker(WorkerBase):
self
.
device
=
torch
.
device
(
f
"cuda:
{
self
.
local_rank
}
"
)
self
.
device
=
torch
.
device
(
f
"cuda:
{
self
.
local_rank
}
"
)
current_platform
.
set_device
(
self
.
device
)
current_platform
.
set_device
(
self
.
device
)
_
check_if_
gpu_
supports_dtype
(
self
.
model_config
.
dtype
)
current_platform
.
check_if_supports_dtype
(
self
.
model_config
.
dtype
)
gc
.
collect
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
...
@@ -612,23 +612,3 @@ def init_worker_distributed_environment(
...
@@ -612,23 +612,3 @@ def init_worker_distributed_environment(
parallel_config
.
pipeline_parallel_size
)
parallel_config
.
pipeline_parallel_size
)
ensure_kv_transfer_initialized
(
vllm_config
)
ensure_kv_transfer_initialized
(
vllm_config
)
def
_check_if_gpu_supports_dtype
(
torch_dtype
:
torch
.
dtype
):
# Check if the GPU supports the dtype.
if
torch_dtype
==
torch
.
bfloat16
:
# noqa: SIM102
if
not
current_platform
.
has_device_capability
(
80
):
capability
=
current_platform
.
get_device_capability
()
gpu_name
=
current_platform
.
get_device_name
()
if
capability
is
None
:
compute_str
=
"does not have a compute capability"
else
:
version_str
=
capability
.
as_version_str
()
compute_str
=
f
"has compute capability
{
version_str
}
"
raise
ValueError
(
"Bfloat16 is only supported on GPUs with compute capability "
f
"of at least 8.0. Your
{
gpu_name
}
GPU
{
compute_str
}
. "
"You can use float16 instead by explicitly setting the "
"`dtype` flag in CLI, for example: --dtype=half."
)
vllm/v1/worker/xpu_worker.py
View file @
8a044754
...
@@ -145,6 +145,7 @@ class XPUWorker(Worker):
...
@@ -145,6 +145,7 @@ class XPUWorker(Worker):
):
):
self
.
device
=
torch
.
device
(
f
"xpu:
{
self
.
local_rank
}
"
)
self
.
device
=
torch
.
device
(
f
"xpu:
{
self
.
local_rank
}
"
)
current_platform
.
set_device
(
self
.
device
)
current_platform
.
set_device
(
self
.
device
)
current_platform
.
check_if_supports_dtype
(
self
.
model_config
.
dtype
)
torch
.
xpu
.
empty_cache
()
torch
.
xpu
.
empty_cache
()
self
.
init_gpu_memory
=
torch
.
xpu
.
get_device_properties
(
self
.
init_gpu_memory
=
torch
.
xpu
.
get_device_properties
(
self
.
local_rank
).
total_memory
self
.
local_rank
).
total_memory
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment