Unverified Commit 6ec5e9fd authored by SherryC41's avatar SherryC41 Committed by GitHub
Browse files

refactor: abstract deepgemm support into platform (#37519)


Co-authored-by: default avatarsherryC41 <sherry.c.c41@gmail.com>
parent e1d85e5c
......@@ -511,6 +511,11 @@ class CudaPlatformBase(Platform):
def support_static_graph_mode(cls) -> bool:
return True
@classmethod
def support_deep_gemm(cls) -> bool:
"""Currently, only Hopper and Blackwell GPUs are supported."""
return cls.is_device_capability(90) or cls.is_device_capability_family(100)
@classmethod
def num_compute_units(cls, device_id: int = 0) -> int:
return torch.cuda.get_device_properties(device_id).multi_processor_count
......
......@@ -712,6 +712,13 @@ class Platform:
"""
return False
@classmethod
def support_deep_gemm(cls) -> bool:
"""
Returns if DeepGEMM is supported by the current platform.
"""
return False
@classmethod
def use_custom_op_collectives(cls) -> bool:
"""
......
......@@ -70,10 +70,7 @@ def is_deep_gemm_supported() -> bool:
"""Return `True` if DeepGEMM is supported on the current platform.
Currently, only Hopper and Blackwell GPUs are supported.
"""
is_supported_arch = current_platform.is_cuda() and (
current_platform.is_device_capability(90)
or current_platform.is_device_capability_family(100)
)
is_supported_arch = current_platform.support_deep_gemm()
return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment