Unverified Commit 4c56e5db authored by lukec's avatar lukec Committed by GitHub
Browse files

Set deepgemm to the default value in the hopper architecture. (#4613)

parent 7b5fc719
...@@ -26,11 +26,14 @@ from sglang.srt.utils import ( ...@@ -26,11 +26,14 @@ from sglang.srt.utils import (
direct_register_custom_op, direct_register_custom_op,
get_device_core_count, get_device_core_count,
get_device_name, get_device_name,
get_device_sm,
is_cuda, is_cuda,
is_hip, is_hip,
supports_custom_op, supports_custom_op,
) )
_enable_jit_deepgemm = False
_is_hip = is_hip() _is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
...@@ -39,9 +42,12 @@ if _is_cuda: ...@@ -39,9 +42,12 @@ if _is_cuda:
import deep_gemm # `pip install "sgl-kernel>=0.0.4.post3"` import deep_gemm # `pip install "sgl-kernel>=0.0.4.post3"`
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8 from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
logger = logging.getLogger(__name__) sm_version = get_device_sm()
if sm_version >= 90 and int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "1")):
_enable_jit_deepgemm = True
_enable_jit_deepgemm = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))
logger = logging.getLogger(__name__)
if supports_custom_op(): if supports_custom_op():
...@@ -771,7 +777,7 @@ def w8a8_block_fp8_matmul( ...@@ -771,7 +777,7 @@ def w8a8_block_fp8_matmul(
) )
# deepgemm only support bf16 # deepgemm only support bf16
if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm: if C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
if supports_custom_op(): if supports_custom_op():
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C) torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
else: else:
......
...@@ -1006,6 +1006,13 @@ def get_amdgpu_memory_capacity(): ...@@ -1006,6 +1006,13 @@ def get_amdgpu_memory_capacity():
) )
def get_device_sm():
    """Return the CUDA compute capability as a single integer.

    Encodes the device capability as ``major * 10 + minor`` (e.g. SM 9.0
    on Hopper -> 90). Returns 0 when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        return 0
    capability = torch.cuda.get_device_capability()
    return capability[0] * 10 + capability[1]
def get_nvgpu_memory_capacity(): def get_nvgpu_memory_capacity():
try: try:
# Run nvidia-smi and capture the output # Run nvidia-smi and capture the output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment