Unverified Commit 17299f08 authored by JieXin Liang's avatar JieXin Liang Committed by GitHub
Browse files

[misc] deep_gemm fallback to NVRTC when NVCC not found (#6252)

parent 5380cd7e
...@@ -15,6 +15,7 @@ _ENABLE_JIT_DEEPGEMM = False ...@@ -15,6 +15,7 @@ _ENABLE_JIT_DEEPGEMM = False
if is_cuda(): if is_cuda():
import deep_gemm import deep_gemm
from deep_gemm import get_num_sms from deep_gemm import get_num_sms
from deep_gemm.jit.compiler import get_nvcc_compiler
from deep_gemm.jit_kernels.gemm import get_best_configs from deep_gemm.jit_kernels.gemm import get_best_configs
from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
from deep_gemm.jit_kernels.tuner import jit_tuner from deep_gemm.jit_kernels.tuner import jit_tuner
...@@ -48,7 +49,17 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv( ...@@ -48,7 +49,17 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f # Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
# NVRTC may cause a performance loss in some cases. # NVRTC may cause a performance loss in some cases.
# Also, NVCC JIT compilation is 9x faster in the referenced commit. # Also, NVCC JIT compilation is 9x faster in the referenced commit.
os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0") _USE_NVRTC_DEFAULT = "0"
# Probe for an NVCC compiler only when DeepGEMM JIT is enabled. If the probe
# raises, warn and flip the NVRTC default to "1".
if _ENABLE_JIT_DEEPGEMM:
    try:
        get_nvcc_compiler()
    except Exception:
        # Catch Exception rather than using a bare `except:` so that
        # KeyboardInterrupt / SystemExit raised during the probe still
        # propagate instead of being silently turned into an NVRTC fallback.
        logger.warning(
            "NVCC Compiler not found, use NVRTC for DeepGEMM JIT "
            "and may have performance loss with some cases."
        )
        _USE_NVRTC_DEFAULT = "1"
os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT)
def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs): def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment