"examples/sampling/graphbolt/vscode:/vscode.git/clone" did not exist on "be0bf495d7f18b61e11b6ae391a936c56956e395"
Unverified Commit eebfdb94 authored by JieXin Liang's avatar JieXin Liang Committed by GitHub
Browse files

[fix] fix potential bumpy throughtput with deepgemm (#5722)

parent dfb32264
......@@ -27,7 +27,7 @@ from sglang.srt.warmup import warmup
multiprocessing.set_start_method("spawn", force=True)
# Reduce warning
os.environ["SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE"] = "1"
os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1"
# Force enable deep gemm
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
......
......@@ -34,9 +34,10 @@ _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
"SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
)
_DO_COMPILE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
_DO_COMPILE_ALL = True
_IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
_COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
_IN_PRE_COMPILE_STAGE = get_bool_env_var("SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE", "false")
_IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
# Force redirect deep_gemm cache_dir
os.environ["DG_CACHE_DIR"] = os.getenv(
......@@ -46,7 +47,8 @@ os.environ["DG_CACHE_DIR"] = os.getenv(
def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
global _BUILTIN_M_LIST
global _DO_COMPILE
global _DO_COMPILE_ALL
global _IS_FIRST_RANK_ON_NODE
# Generate m_max
m_max = 1024 * 16
......@@ -57,8 +59,13 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
m_max = min(1024 * 128, m_max)
_BUILTIN_M_LIST = list(range(1, m_max + 1))
# Check if is the first rank on node
_DO_COMPILE = ServerArgs.base_gpu_id == gpu_id
_IS_FIRST_RANK_ON_NODE = ServerArgs.base_gpu_id == gpu_id
# Check if is the first rank on node.
# Default each rank will try compile all Ms to
# load all symbols at the launch stages.
# Avoid loading symbols at the serving stages.
_DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE or not _IN_PRECOMPILE_STAGE
class DeepGemmKernelType(IntEnum):
......@@ -89,7 +96,7 @@ _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dic
def _compile_warning_1():
if not _IN_PRE_COMPILE_STAGE:
if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
logger.warning(
"Entering DeepGEMM JIT Pre-Complie session. "
"And it may takes a long time(Typically 10-20 mins) "
......@@ -276,7 +283,7 @@ def _maybe_compile_deep_gemm_one_type_all(
query_key = (kernel_type, n, k, num_groups)
if (
_ENABLE_JIT_DEEPGEMM_PRECOMPILE
and _DO_COMPILE
and _DO_COMPILE_ALL
and _INITIALIZATION_DICT.get(query_key) is None
):
_INITIALIZATION_DICT[query_key] = True
......@@ -286,7 +293,7 @@ def _maybe_compile_deep_gemm_one_type_all(
logger.info(
f"Try DeepGEMM JIT Compiling for "
f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRE_COMPILE_STAGE else ''}"
f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
)
# NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
......@@ -355,7 +362,7 @@ def gemm_nt_f8f8bf16(
@contextmanager
def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
if _IN_PRE_COMPILE_STAGE:
if _IN_PRECOMPILE_STAGE:
yield
return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment