Unverified Commit eebfdb94 authored by JieXin Liang's avatar JieXin Liang Committed by GitHub
Browse files

[fix] fix potential bumpy throughput with deepgemm (#5722)

parent dfb32264
...@@ -27,7 +27,7 @@ from sglang.srt.warmup import warmup ...@@ -27,7 +27,7 @@ from sglang.srt.warmup import warmup
multiprocessing.set_start_method("spawn", force=True) multiprocessing.set_start_method("spawn", force=True)
# Reduce warning # Reduce warning
os.environ["SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE"] = "1" os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1"
# Force enable deep gemm # Force enable deep gemm
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1" os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case # Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
......
...@@ -34,9 +34,10 @@ _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1)) ...@@ -34,9 +34,10 @@ _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var( _ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
"SGL_JIT_DEEPGEMM_PRECOMPILE", "true" "SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
) )
_DO_COMPILE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true") _DO_COMPILE_ALL = True
_IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
_COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4) _COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
_IN_PRE_COMPILE_STAGE = get_bool_env_var("SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE", "false") _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
# Force redirect deep_gemm cache_dir # Force redirect deep_gemm cache_dir
os.environ["DG_CACHE_DIR"] = os.getenv( os.environ["DG_CACHE_DIR"] = os.getenv(
...@@ -46,7 +47,8 @@ os.environ["DG_CACHE_DIR"] = os.getenv( ...@@ -46,7 +47,8 @@ os.environ["DG_CACHE_DIR"] = os.getenv(
def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs): def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
global _BUILTIN_M_LIST global _BUILTIN_M_LIST
global _DO_COMPILE global _DO_COMPILE_ALL
global _IS_FIRST_RANK_ON_NODE
# Generate m_max # Generate m_max
m_max = 1024 * 16 m_max = 1024 * 16
...@@ -57,8 +59,13 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs): ...@@ -57,8 +59,13 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
m_max = min(1024 * 128, m_max) m_max = min(1024 * 128, m_max)
_BUILTIN_M_LIST = list(range(1, m_max + 1)) _BUILTIN_M_LIST = list(range(1, m_max + 1))
# Check if is the first rank on node _IS_FIRST_RANK_ON_NODE = ServerArgs.base_gpu_id == gpu_id
_DO_COMPILE = ServerArgs.base_gpu_id == gpu_id
# Check whether this is the first rank on the node.
# By default, each rank will try to compile all Ms so that
# all symbols are loaded during the launch stage.
# This avoids loading symbols during the serving stage.
_DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE or not _IN_PRECOMPILE_STAGE
class DeepGemmKernelType(IntEnum): class DeepGemmKernelType(IntEnum):
...@@ -89,7 +96,7 @@ _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dic ...@@ -89,7 +96,7 @@ _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dic
def _compile_warning_1(): def _compile_warning_1():
if not _IN_PRE_COMPILE_STAGE: if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
logger.warning( logger.warning(
"Entering DeepGEMM JIT Pre-Complie session. " "Entering DeepGEMM JIT Pre-Complie session. "
"And it may takes a long time(Typically 10-20 mins) " "And it may takes a long time(Typically 10-20 mins) "
...@@ -276,7 +283,7 @@ def _maybe_compile_deep_gemm_one_type_all( ...@@ -276,7 +283,7 @@ def _maybe_compile_deep_gemm_one_type_all(
query_key = (kernel_type, n, k, num_groups) query_key = (kernel_type, n, k, num_groups)
if ( if (
_ENABLE_JIT_DEEPGEMM_PRECOMPILE _ENABLE_JIT_DEEPGEMM_PRECOMPILE
and _DO_COMPILE and _DO_COMPILE_ALL
and _INITIALIZATION_DICT.get(query_key) is None and _INITIALIZATION_DICT.get(query_key) is None
): ):
_INITIALIZATION_DICT[query_key] = True _INITIALIZATION_DICT[query_key] = True
...@@ -286,7 +293,7 @@ def _maybe_compile_deep_gemm_one_type_all( ...@@ -286,7 +293,7 @@ def _maybe_compile_deep_gemm_one_type_all(
logger.info( logger.info(
f"Try DeepGEMM JIT Compiling for " f"Try DeepGEMM JIT Compiling for "
f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms." f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRE_COMPILE_STAGE else ''}" f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
) )
# NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
...@@ -355,7 +362,7 @@ def gemm_nt_f8f8bf16( ...@@ -355,7 +362,7 @@ def gemm_nt_f8f8bf16(
@contextmanager @contextmanager
def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType): def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
if _IN_PRE_COMPILE_STAGE: if _IN_PRECOMPILE_STAGE:
yield yield
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment