Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6ffa3f31
Unverified
Commit
6ffa3f31
authored
Sep 18, 2024
by
Cyrus Leung
Committed by
GitHub
Sep 18, 2024
Browse files
[CI/Build] Avoid CUDA initialization (#8534)
parent
e3515729
Changes
55
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
135 additions
and
61 deletions
+135
-61
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+6
-4
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
...el_executor/layers/quantization/utils/marlin_utils_fp8.py
+1
-2
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+3
-2
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+3
-3
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+1
-1
vllm/model_executor/utils.py
vllm/model_executor/utils.py
+3
-7
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+4
-4
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+9
-8
vllm/platforms/interface.py
vllm/platforms/interface.py
+55
-7
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+7
-7
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+6
-2
vllm/prompt_adapter/utils.py
vllm/prompt_adapter/utils.py
+3
-1
vllm/usage/usage_lib.py
vllm/usage/usage_lib.py
+2
-1
vllm/utils.py
vllm/utils.py
+21
-7
vllm/worker/worker.py
vllm/worker/worker.py
+11
-5
No files found.
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
6ffa3f31
...
@@ -29,8 +29,9 @@ def query_marlin_supported_quant_types(has_zp: bool,
...
@@ -29,8 +29,9 @@ def query_marlin_supported_quant_types(has_zp: bool,
device_capability
:
Optional
[
int
]
=
None
device_capability
:
Optional
[
int
]
=
None
):
):
if
device_capability
is
None
:
if
device_capability
is
None
:
major
,
minor
=
current_platform
.
get_device_capability
()
capability_tuple
=
current_platform
.
get_device_capability
()
device_capability
=
major
*
10
+
minor
device_capability
=
(
-
1
if
capability_tuple
is
None
else
capability_tuple
.
to_int
())
if
device_capability
<
80
:
if
device_capability
<
80
:
return
[]
return
[]
...
@@ -52,8 +53,9 @@ def _check_marlin_supported(
...
@@ -52,8 +53,9 @@ def _check_marlin_supported(
device_capability
:
Optional
[
int
]
=
None
)
->
Tuple
[
bool
,
Optional
[
str
]]:
device_capability
:
Optional
[
int
]
=
None
)
->
Tuple
[
bool
,
Optional
[
str
]]:
if
device_capability
is
None
:
if
device_capability
is
None
:
major
,
minor
=
current_platform
.
get_device_capability
()
capability_tuple
=
current_platform
.
get_device_capability
()
device_capability
=
major
*
10
+
minor
device_capability
=
(
-
1
if
capability_tuple
is
None
else
capability_tuple
.
to_int
())
supported_types
=
query_marlin_supported_quant_types
(
supported_types
=
query_marlin_supported_quant_types
(
has_zp
,
device_capability
)
has_zp
,
device_capability
)
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
View file @
6ffa3f31
...
@@ -10,8 +10,7 @@ from .marlin_utils import marlin_make_workspace, marlin_permute_scales
...
@@ -10,8 +10,7 @@ from .marlin_utils import marlin_make_workspace, marlin_permute_scales
def
is_fp8_marlin_supported
():
def
is_fp8_marlin_supported
():
capability
=
current_platform
.
get_device_capability
()
return
current_platform
.
has_device_capability
(
80
)
return
capability
[
0
]
>=
8
def
apply_fp8_marlin_linear
(
def
apply_fp8_marlin_linear
(
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
6ffa3f31
...
@@ -17,8 +17,9 @@ def cutlass_fp8_supported() -> bool:
...
@@ -17,8 +17,9 @@ def cutlass_fp8_supported() -> bool:
# cutlass is not supported on Rocm
# cutlass is not supported on Rocm
if
is_hip
():
if
is_hip
():
return
False
return
False
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
capability_tuple
=
current_platform
.
get_device_capability
()
capability
=
-
1
if
capability_tuple
is
None
else
capability_tuple
.
to_int
()
return
ops
.
cutlass_scaled_mm_supports_fp8
(
capability
)
return
ops
.
cutlass_scaled_mm_supports_fp8
(
capability
)
...
...
vllm/model_executor/model_loader/loader.py
View file @
6ffa3f31
...
@@ -97,10 +97,10 @@ def _get_quantization_config(
...
@@ -97,10 +97,10 @@ def _get_quantization_config(
"""Get the quantization config."""
"""Get the quantization config."""
if
model_config
.
quantization
is
not
None
:
if
model_config
.
quantization
is
not
None
:
quant_config
=
get_quant_config
(
model_config
,
load_config
)
quant_config
=
get_quant_config
(
model_config
,
load_config
)
capability
=
current_platform
.
get_device_capability
()
# type: ignore
capability
_tuple
=
current_platform
.
get_device_capability
()
if
capability
is
not
None
:
if
capability
_tuple
is
not
None
:
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
capability
=
capability
_tuple
.
to_int
()
if
capability
<
quant_config
.
get_min_capability
():
if
capability
<
quant_config
.
get_min_capability
():
raise
ValueError
(
raise
ValueError
(
f
"The quantization method
{
model_config
.
quantization
}
"
f
"The quantization method
{
model_config
.
quantization
}
"
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
6ffa3f31
...
@@ -207,7 +207,7 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -207,7 +207,7 @@ class Qwen2VisionAttention(nn.Module):
selected_backend
=
backend_name_to_enum
(
backend_by_env_var
)
selected_backend
=
backend_name_to_enum
(
backend_by_env_var
)
if
selected_backend
is
None
:
if
selected_backend
is
None
:
# For Volta and Turing GPUs, use xformers instead.
# For Volta and Turing GPUs, use xformers instead.
device_available
=
current_platform
.
get
_device_capability
(
)[
0
]
>=
8
device_available
=
current_platform
.
has
_device_capability
(
80
)
if
device_available
:
if
device_available
:
from
transformers.utils
import
is_flash_attn_2_available
from
transformers.utils
import
is_flash_attn_2_available
...
...
vllm/model_executor/utils.py
View file @
6ffa3f31
"""Utils for model executor."""
"""Utils for model executor."""
import
random
from
typing
import
Any
,
Dict
,
Optional
from
typing
import
Any
,
Dict
,
Optional
import
numpy
as
np
import
torch
import
torch
from
vllm.utils
import
seed_everything
def
set_random_seed
(
seed
:
int
)
->
None
:
def
set_random_seed
(
seed
:
int
)
->
None
:
random
.
seed
(
seed
)
seed_everything
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
seed
)
def
set_weight_attrs
(
def
set_weight_attrs
(
...
...
vllm/platforms/cpu.py
View file @
6ffa3f31
...
@@ -6,10 +6,10 @@ from .interface import Platform, PlatformEnum
...
@@ -6,10 +6,10 @@ from .interface import Platform, PlatformEnum
class
CpuPlatform
(
Platform
):
class
CpuPlatform
(
Platform
):
_enum
=
PlatformEnum
.
CPU
_enum
=
PlatformEnum
.
CPU
@
static
method
@
class
method
def
get_device_name
(
device_id
:
int
=
0
)
->
str
:
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
return
"cpu"
return
"cpu"
@
static
method
@
class
method
def
inference_mode
():
def
inference_mode
(
cls
):
return
torch
.
no_grad
()
return
torch
.
no_grad
()
vllm/platforms/cuda.py
View file @
6ffa3f31
...
@@ -11,7 +11,7 @@ from typing_extensions import ParamSpec
...
@@ -11,7 +11,7 @@ from typing_extensions import ParamSpec
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
.interface
import
Platform
,
PlatformEnum
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -96,19 +96,20 @@ def device_id_to_physical_device_id(device_id: int) -> int:
...
@@ -96,19 +96,20 @@ def device_id_to_physical_device_id(device_id: int) -> int:
class
CudaPlatform
(
Platform
):
class
CudaPlatform
(
Platform
):
_enum
=
PlatformEnum
.
CUDA
_enum
=
PlatformEnum
.
CUDA
@
static
method
@
class
method
def
get_device_capability
(
device_id
:
int
=
0
)
->
Tuple
[
int
,
int
]
:
def
get_device_capability
(
cls
,
device_id
:
int
=
0
)
->
DeviceCapability
:
physical_device_id
=
device_id_to_physical_device_id
(
device_id
)
physical_device_id
=
device_id_to_physical_device_id
(
device_id
)
return
get_physical_device_capability
(
physical_device_id
)
major
,
minor
=
get_physical_device_capability
(
physical_device_id
)
return
DeviceCapability
(
major
=
major
,
minor
=
minor
)
@
static
method
@
class
method
def
get_device_name
(
device_id
:
int
=
0
)
->
str
:
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
physical_device_id
=
device_id_to_physical_device_id
(
device_id
)
physical_device_id
=
device_id_to_physical_device_id
(
device_id
)
return
get_physical_device_name
(
physical_device_id
)
return
get_physical_device_name
(
physical_device_id
)
@
static
method
@
class
method
@
with_nvml_context
@
with_nvml_context
def
is_full_nvlink
(
physical_device_ids
:
List
[
int
])
->
bool
:
def
is_full_nvlink
(
cls
,
physical_device_ids
:
List
[
int
])
->
bool
:
"""
"""
query if the set of gpus are fully connected by nvlink (1 hop)
query if the set of gpus are fully connected by nvlink (1 hop)
"""
"""
...
...
vllm/platforms/interface.py
View file @
6ffa3f31
import
enum
import
enum
from
typing
import
Optional
,
Tuple
from
typing
import
NamedTuple
,
Optional
,
Tuple
,
Union
import
torch
import
torch
...
@@ -12,6 +12,23 @@ class PlatformEnum(enum.Enum):
...
@@ -12,6 +12,23 @@ class PlatformEnum(enum.Enum):
UNSPECIFIED
=
enum
.
auto
()
UNSPECIFIED
=
enum
.
auto
()
class
DeviceCapability
(
NamedTuple
):
major
:
int
minor
:
int
def
as_version_str
(
self
)
->
str
:
return
f
"
{
self
.
major
}
.
{
self
.
minor
}
"
def
to_int
(
self
)
->
int
:
"""
Express device capability as an integer ``<major><minor>``.
It is assumed that the minor version is always a single digit.
"""
assert
0
<=
self
.
minor
<
10
return
self
.
major
*
10
+
self
.
minor
class
Platform
:
class
Platform
:
_enum
:
PlatformEnum
_enum
:
PlatformEnum
...
@@ -27,16 +44,47 @@ class Platform:
...
@@ -27,16 +44,47 @@ class Platform:
def
is_cpu
(
self
)
->
bool
:
def
is_cpu
(
self
)
->
bool
:
return
self
.
_enum
==
PlatformEnum
.
CPU
return
self
.
_enum
==
PlatformEnum
.
CPU
@
staticmethod
def
is_cuda_alike
(
self
)
->
bool
:
def
get_device_capability
(
device_id
:
int
=
0
)
->
Optional
[
Tuple
[
int
,
int
]]:
"""Stateless version of :func:`torch.cuda.is_available`."""
return
self
.
_enum
in
(
PlatformEnum
.
CUDA
,
PlatformEnum
.
ROCM
)
@
classmethod
def
get_device_capability
(
cls
,
device_id
:
int
=
0
,
)
->
Optional
[
DeviceCapability
]:
"""Stateless version of :func:`torch.cuda.get_device_capability`."""
return
None
return
None
@
staticmethod
@
classmethod
def
get_device_name
(
device_id
:
int
=
0
)
->
str
:
def
has_device_capability
(
cls
,
capability
:
Union
[
Tuple
[
int
,
int
],
int
],
device_id
:
int
=
0
,
)
->
bool
:
"""
Test whether this platform is compatible with a device capability.
The ``capability`` argument can either be:
- A tuple ``(major, minor)``.
- An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
"""
current_capability
=
cls
.
get_device_capability
(
device_id
=
device_id
)
if
current_capability
is
None
:
return
False
if
isinstance
(
capability
,
tuple
):
return
current_capability
>=
capability
return
current_capability
.
to_int
()
>=
capability
@
classmethod
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
raise
NotImplementedError
raise
NotImplementedError
@
static
method
@
class
method
def
inference_mode
():
def
inference_mode
(
cls
):
"""A device-specific wrapper of `torch.inference_mode`.
"""A device-specific wrapper of `torch.inference_mode`.
This wrapper is recommended because some hardware backends such as TPU
This wrapper is recommended because some hardware backends such as TPU
...
...
vllm/platforms/rocm.py
View file @
6ffa3f31
import
os
import
os
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
Tuple
import
torch
import
torch
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
.interface
import
Platform
,
PlatformEnum
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -20,12 +19,13 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
...
@@ -20,12 +19,13 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
class
RocmPlatform
(
Platform
):
class
RocmPlatform
(
Platform
):
_enum
=
PlatformEnum
.
ROCM
_enum
=
PlatformEnum
.
ROCM
@
static
method
@
class
method
@
lru_cache
(
maxsize
=
8
)
@
lru_cache
(
maxsize
=
8
)
def
get_device_capability
(
device_id
:
int
=
0
)
->
Tuple
[
int
,
int
]:
def
get_device_capability
(
cls
,
device_id
:
int
=
0
)
->
DeviceCapability
:
return
torch
.
cuda
.
get_device_capability
(
device_id
)
major
,
minor
=
torch
.
cuda
.
get_device_capability
(
device_id
)
return
DeviceCapability
(
major
=
major
,
minor
=
minor
)
@
static
method
@
class
method
@
lru_cache
(
maxsize
=
8
)
@
lru_cache
(
maxsize
=
8
)
def
get_device_name
(
device_id
:
int
=
0
)
->
str
:
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
return
torch
.
cuda
.
get_device_name
(
device_id
)
return
torch
.
cuda
.
get_device_name
(
device_id
)
vllm/platforms/tpu.py
View file @
6ffa3f31
...
@@ -6,6 +6,10 @@ from .interface import Platform, PlatformEnum
...
@@ -6,6 +6,10 @@ from .interface import Platform, PlatformEnum
class
TpuPlatform
(
Platform
):
class
TpuPlatform
(
Platform
):
_enum
=
PlatformEnum
.
TPU
_enum
=
PlatformEnum
.
TPU
@
staticmethod
@
classmethod
def
inference_mode
():
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
raise
NotImplementedError
@
classmethod
def
inference_mode
(
cls
):
return
torch
.
no_grad
()
return
torch
.
no_grad
()
vllm/prompt_adapter/utils.py
View file @
6ffa3f31
...
@@ -8,13 +8,15 @@ from huggingface_hub import file_exists, hf_hub_download
...
@@ -8,13 +8,15 @@ from huggingface_hub import file_exists, hf_hub_download
from
huggingface_hub.utils
import
EntryNotFoundError
from
huggingface_hub.utils
import
EntryNotFoundError
from
safetensors.torch
import
load_file
as
safe_load_file
from
safetensors.torch
import
load_file
as
safe_load_file
from
vllm.platforms
import
current_platform
WEIGHTS_NAME
=
"adapter_model.bin"
WEIGHTS_NAME
=
"adapter_model.bin"
SAFETENSORS_WEIGHTS_NAME
=
"adapter_model.safetensors"
SAFETENSORS_WEIGHTS_NAME
=
"adapter_model.safetensors"
# Get current device name based on available devices
# Get current device name based on available devices
def
infer_device
()
->
str
:
def
infer_device
()
->
str
:
if
torch
.
cuda
.
is_availabl
e
():
if
current_platform
.
is_cuda_alik
e
():
return
"cuda"
return
"cuda"
return
"cpu"
return
"cpu"
...
...
vllm/usage/usage_lib.py
View file @
6ffa3f31
...
@@ -17,6 +17,7 @@ import torch
...
@@ -17,6 +17,7 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.connections
import
global_http_connection
from
vllm.connections
import
global_http_connection
from
vllm.platforms
import
current_platform
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
_config_home
=
envs
.
VLLM_CONFIG_ROOT
_config_home
=
envs
.
VLLM_CONFIG_ROOT
...
@@ -151,7 +152,7 @@ class UsageMessage:
...
@@ -151,7 +152,7 @@ class UsageMessage:
usage_context
:
UsageContext
,
usage_context
:
UsageContext
,
extra_kvs
:
Dict
[
str
,
Any
])
->
None
:
extra_kvs
:
Dict
[
str
,
Any
])
->
None
:
# Platform information
# Platform information
if
torch
.
cuda
.
is_availabl
e
():
if
current_platform
.
is_cuda_alik
e
():
device_property
=
torch
.
cuda
.
get_device_properties
(
0
)
device_property
=
torch
.
cuda
.
get_device_properties
(
0
)
self
.
gpu_count
=
torch
.
cuda
.
device_count
()
self
.
gpu_count
=
torch
.
cuda
.
device_count
()
self
.
gpu_type
=
device_property
.
name
self
.
gpu_type
=
device_property
.
name
...
...
vllm/utils.py
View file @
6ffa3f31
...
@@ -5,6 +5,7 @@ import datetime
...
@@ -5,6 +5,7 @@ import datetime
import
enum
import
enum
import
gc
import
gc
import
os
import
os
import
random
import
socket
import
socket
import
subprocess
import
subprocess
import
sys
import
sys
...
@@ -32,6 +33,7 @@ from typing_extensions import ParamSpec, TypeIs, assert_never
...
@@ -32,6 +33,7 @@ from typing_extensions import ParamSpec, TypeIs, assert_never
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.logger
import
enable_trace_function_call
,
init_logger
from
vllm.logger
import
enable_trace_function_call
,
init_logger
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -373,6 +375,22 @@ def get_cpu_memory() -> int:
...
@@ -373,6 +375,22 @@ def get_cpu_memory() -> int:
return
psutil
.
virtual_memory
().
total
return
psutil
.
virtual_memory
().
total
def
seed_everything
(
seed
:
int
)
->
None
:
"""
Set the seed of each random module.
Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
"""
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
if
current_platform
.
is_cuda_alike
():
torch
.
cuda
.
manual_seed_all
(
seed
)
if
is_xpu
():
torch
.
xpu
.
manual_seed_all
(
seed
)
def
random_uuid
()
->
str
:
def
random_uuid
()
->
str
:
return
str
(
uuid
.
uuid4
().
hex
)
return
str
(
uuid
.
uuid4
().
hex
)
...
@@ -634,9 +652,7 @@ def create_kv_caches_with_random_flash(
...
@@ -634,9 +652,7 @@ def create_kv_caches_with_random_flash(
seed
:
int
=
0
,
seed
:
int
=
0
,
device
:
Optional
[
str
]
=
"cuda"
,
device
:
Optional
[
str
]
=
"cuda"
,
)
->
Tuple
[
List
[
torch
.
Tensor
],
List
[
torch
.
Tensor
]]:
)
->
Tuple
[
List
[
torch
.
Tensor
],
List
[
torch
.
Tensor
]]:
torch
.
random
.
manual_seed
(
seed
)
seed_everything
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
key_value_cache_shape
=
(
num_blocks
,
2
,
block_size
,
num_heads
,
head_size
)
key_value_cache_shape
=
(
num_blocks
,
2
,
block_size
,
num_heads
,
head_size
)
...
@@ -678,9 +694,7 @@ def create_kv_caches_with_random(
...
@@ -678,9 +694,7 @@ def create_kv_caches_with_random(
f
"Does not support key cache of type fp8 with head_size
{
head_size
}
"
f
"Does not support key cache of type fp8 with head_size
{
head_size
}
"
)
)
torch
.
random
.
manual_seed
(
seed
)
seed_everything
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
...
@@ -750,7 +764,7 @@ class CudaMemoryProfiler:
...
@@ -750,7 +764,7 @@ class CudaMemoryProfiler:
def
current_memory_usage
(
self
)
->
float
:
def
current_memory_usage
(
self
)
->
float
:
# Return the memory usage in bytes.
# Return the memory usage in bytes.
if
torch
.
cuda
.
is_availabl
e
():
if
current_platform
.
is_cuda_alik
e
():
torch
.
cuda
.
reset_peak_memory_stats
(
self
.
device
)
torch
.
cuda
.
reset_peak_memory_stats
(
self
.
device
)
mem
=
torch
.
cuda
.
max_memory_allocated
(
self
.
device
)
mem
=
torch
.
cuda
.
max_memory_allocated
(
self
.
device
)
elif
is_xpu
():
elif
is_xpu
():
...
...
vllm/worker/worker.py
View file @
6ffa3f31
...
@@ -454,14 +454,20 @@ def init_worker_distributed_environment(
...
@@ -454,14 +454,20 @@ def init_worker_distributed_environment(
def
_check_if_gpu_supports_dtype
(
torch_dtype
:
torch
.
dtype
):
def
_check_if_gpu_supports_dtype
(
torch_dtype
:
torch
.
dtype
):
# Check if the GPU supports the dtype.
# Check if the GPU supports the dtype.
if
torch_dtype
==
torch
.
bfloat16
:
if
torch_dtype
==
torch
.
bfloat16
:
# noqa: SIM102
compute_capability
=
current_platform
.
get
_device_capability
(
)
if
not
current_platform
.
has
_device_capability
(
80
):
if
comput
e_capability
[
0
]
<
8
:
capability
=
current_platform
.
get_devic
e_capability
()
gpu_name
=
current_platform
.
get_device_name
()
gpu_name
=
current_platform
.
get_device_name
()
if
capability
is
None
:
compute_str
=
"does not have a compute capability"
else
:
version_str
=
capability
.
as_version_str
()
compute_str
=
f
"has compute capability
{
version_str
}
"
raise
ValueError
(
raise
ValueError
(
"Bfloat16 is only supported on GPUs with compute capability "
"Bfloat16 is only supported on GPUs with compute capability "
f
"of at least 8.0. Your
{
gpu_name
}
GPU has compute capability "
f
"of at least 8.0. Your
{
gpu_name
}
GPU
{
compute_str
}
. "
f
"
{
compute_capability
[
0
]
}
.
{
compute_capability
[
1
]
}
. "
"You can use float16 instead by explicitly setting the"
"You can use float16 instead by explicitly setting the"
"`dtype` flag in CLI, for example: --dtype=half."
)
"`dtype` flag in CLI, for example: --dtype=half."
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment