Unverified commit 23010630, authored by Yineng Zhang and committed by GitHub

chore: upgrade sgl-kernel v0.1.2.post1 (#6196)

Co-authored-by: alcanderian <alcanderian@gmail.com>
parent 45b4dcf0
@@ -48,7 +48,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.1.1",
+    "sgl-kernel==0.1.2.post1",
     "flashinfer_python==0.2.5",
     "torch==2.6.0",
     "torchvision==0.21.0",
@@ -486,7 +486,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.1.1",
+            "0.1.2.post1",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
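For context, the guard above only compares the installed and pinned versions and points the user at a reinstall command. A minimal sketch of an equivalent check (an illustration, not sglang's `assert_pkg_version` implementation; it assumes `importlib.metadata` and the `packaging` package are available):

# Illustrative only: not sglang's assert_pkg_version implementation.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version

REINSTALL_HINT = (
    "Please reinstall the latest version with "
    "`pip install sgl-kernel --force-reinstall`"
)

def check_sgl_kernel(min_version: str = "0.1.2.post1") -> None:
    """Fail fast if sgl-kernel is missing or older than the pinned version."""
    try:
        installed = Version(version("sgl-kernel"))
    except PackageNotFoundError as exc:
        raise RuntimeError(f"sgl-kernel is not installed. {REINSTALL_HINT}") from exc
    if installed < Version(min_version):
        raise RuntimeError(
            f"sgl-kernel {installed} < required {min_version}. {REINSTALL_HINT}"
        )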
@@ -16,11 +16,7 @@ if is_cuda():
     import deep_gemm
     from deep_gemm import get_num_sms
     from deep_gemm.jit_kernels.gemm import get_best_configs
-    from deep_gemm.jit_kernels.gemm import includes as deep_gemm_includes
-    from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
-    from deep_gemm.jit_kernels.m_grouped_gemm import (
-        template as deep_gemm_grouped_gemm_template,
-    )
+    from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
     from deep_gemm.jit_kernels.tuner import jit_tuner

     sm_version = get_device_sm()
@@ -45,10 +41,15 @@ _COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
 _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")

 # Force redirect deep_gemm cache_dir
-os.environ["DG_CACHE_DIR"] = os.getenv(
-    "SGL_DG_CACHE_DIR", os.path.expanduser("~") + "/.cache/deep_gemm"
+os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
+    "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
 )

+# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
+# NVRTC may cause a performance loss in some cases,
+# and NVCC JIT compilation is also ~9x faster in the referenced commit.
+os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")
+

 def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
     global _BUILTIN_M_LIST
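The two `SGL_*` variables above are read once at module import time, so any override has to be in place before this module is loaded. A minimal sketch, assuming the overrides are set programmatically before importing sglang (the cache path below is only an example):

import os

# Must run before sglang's deep_gemm wrapper module is imported, because the
# os.environ redirection above happens at import time.
os.environ["SGL_DG_CACHE_DIR"] = "/data/cache/deep_gemm"  # exported as DG_JIT_CACHE_DIR
os.environ["SGL_DG_USE_NVRTC"] = "1"  # opt back into NVRTC; the default "0" keeps NVCC

import sglang  # noqa: E402  -- imported after the overrides so they take effect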
@@ -130,10 +131,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    # Auto-tuning with compilation
-    global deep_gemm_includes, deep_gemm_grouped_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -146,24 +155,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE": "GroupedMasked",
+            "GEMM_TYPE": GemmType.GroupedMasked,
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )
@@ -173,9 +169,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    global deep_gemm_includes, deep_gemm_grouped_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -188,25 +193,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE": "GroupedContiguous",
+            "GEMM_TYPE": GemmType.GroupedContiguous,
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("num_groups", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )
@@ -216,9 +207,20 @@ def _compile_gemm_nt_f8f8bf16_one(
     _: int,  # _ is a dummy parameter to align with other interfaces
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    global deep_gemm_includes, deep_gemm_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "GEMM_TYPE": GemmType.Normal,
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "NUM_GROUPS": 1,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -232,20 +234,8 @@ def _compile_gemm_nt_f8f8bf16_one(
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )
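All three `_compile_*_one` helpers above now assemble the same `FP8GemmRuntime` template parameters by hand. A hypothetical refactor (not part of this commit) could share them, e.g.:

from typing import Dict

def _fp8_gemm_runtime_kwargs(num_sms: int, smem_size: int) -> Dict[str, int]:
    """Template parameters common to the normal, contiguous, and masked kernels."""
    return {
        "NUM_TMA_THREADS": 128,
        "NUM_MATH_THREADS_PER_GROUP": 128,
        "BLOCK_K": 128,
        "NUM_SMS": num_sms,
        "SMEM_SIZE": smem_size,
    }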
@@ -373,7 +363,7 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
     from deep_gemm.jit.runtime import RuntimeCache

-    origin_func = RuntimeCache.__getitem__
+    origin_func = RuntimeCache.get

     def __patched_func(self, *args, **kwargs):
         ret = origin_func(self, *args, **kwargs)
@@ -385,6 +375,6 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
         )
         return ret

-    RuntimeCache.__getitem__ = __patched_func
+    RuntimeCache.get = __patched_func
     yield
-    RuntimeCache.__getitem__ = origin_func
+    RuntimeCache.get = origin_func
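The two hunks above only change which `RuntimeCache` method gets wrapped (`get` instead of `__getitem__`); the wrap-log-restore pattern itself is unchanged. A self-contained sketch of that pattern (the class and logger below are stand-ins for illustration, not DeepGEMM's actual API):

import logging
from contextlib import contextmanager

logger = logging.getLogger(__name__)

class DemoRuntimeCache:
    """Stand-in for deep_gemm.jit.runtime.RuntimeCache."""

    def __init__(self):
        self._runtimes = {}

    def get(self, key, default=None):
        return self._runtimes.get(key, default)

@contextmanager
def log_jit_build(kernel_desc: str):
    origin_func = DemoRuntimeCache.get

    def __patched_func(self, *args, **kwargs):
        ret = origin_func(self, *args, **kwargs)
        if ret is None:
            # A cache miss means the JIT compiler is about to build this kernel.
            logger.warning("DeepGEMM JIT build triggered for %s", kernel_desc)
        return ret

    DemoRuntimeCache.get = __patched_func
    try:
        yield
    finally:
        DemoRuntimeCache.get = origin_func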
@@ -16,7 +16,7 @@ rm -rf /usr/local/lib/python3.10/dist-packages/sgl_kernel*
 pip install --upgrade pip

 # Install sgl-kernel
-pip install sgl-kernel==0.1.1 --no-cache-dir
+pip install sgl-kernel==0.1.2.post1 --no-cache-dir

 # Install the main package
 pip install -e "python[all]"
@@ -34,7 +34,7 @@ rm -rf /usr/local/include/nvshmem*
 pip install --upgrade pip

 # Install sgl-kernel
-pip install sgl-kernel==0.1.1 --no-cache-dir
+pip install sgl-kernel==0.1.2.post1 --no-cache-dir

 # Install the main package
 pip install -e "python[all]"