Unverified Commit b2435be6 authored by b8zhong's avatar b8zhong Committed by GitHub
Browse files

Cache the result of `is_blackwell` platform check (#10498)

parent 5fe39e85
...@@ -2,7 +2,7 @@ import logging ...@@ -2,7 +2,7 @@ import logging
import torch import torch
from sglang.srt.utils import get_bool_env_var, get_device_sm from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -21,12 +21,7 @@ def _compute_enable_deep_gemm(): ...@@ -21,12 +21,7 @@ def _compute_enable_deep_gemm():
return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true") return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
def _is_blackwell_arch() -> bool:
major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
return major == 10
ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm() ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and _is_blackwell_arch() DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and is_blackwell()
DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
...@@ -167,6 +167,7 @@ is_ampere_with_cuda_12_3 = lambda: _check(8) ...@@ -167,6 +167,7 @@ is_ampere_with_cuda_12_3 = lambda: _check(8)
is_hopper_with_cuda_12_3 = lambda: _check(9) is_hopper_with_cuda_12_3 = lambda: _check(9)
@lru_cache(maxsize=1)
def is_blackwell(): def is_blackwell():
if not is_cuda(): if not is_cuda():
return False return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment