Unverified commit 6153f2ff, authored by JieXin Liang, committed by GitHub

chore: upgrade sgl-kernel v0.1.6 (#6945)

parent 8b5f83ed
@@ -49,7 +49,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.1.5",
+    "sgl-kernel==0.1.6",
     "flashinfer_python==0.2.5",
     "torch==2.6.0",
     "torchvision==0.21.0",
@@ -579,7 +579,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.1.5",
+            "0.1.6",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
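For readers tracing the failure mode: this guard is why a stale wheel aborts server startup with the reinstall hint. A minimal stand-in for the check, assuming only `importlib.metadata` and `packaging` (sglang's real `assert_pkg_version` lives elsewhere in the codebase and may differ):

```python
# Illustrative sketch, not sglang's implementation: fail fast when an
# installed package is older than the version the server expects.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import parse


def check_min_version(pkg: str, min_version: str, hint: str) -> None:
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {hint}")
    if parse(installed) < parse(min_version):
        raise RuntimeError(
            f"{pkg}=={installed} is older than required {min_version}. {hint}"
        )


check_min_version(
    "sgl-kernel",
    "0.1.6",
    "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
)
```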
@@ -17,10 +17,10 @@ _ENABLE_JIT_DEEPGEMM = False
 try:
     import deep_gemm
     from deep_gemm import get_num_sms
+    from deep_gemm.jit import build
     from deep_gemm.jit.compiler import get_nvcc_compiler
     from deep_gemm.jit_kernels.gemm import get_best_configs
     from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
-    from deep_gemm.jit_kernels.tuner import jit_tuner

     sm_version = get_device_sm()
     if sm_version == 90:
@@ -148,32 +148,28 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
     block_k = 128
     num_tma_threads = 128
     num_math_threads_per_group = 128
     kwargs = {
+        "GEMM_TYPE": GemmType.GroupedMasked,
         "NUM_TMA_THREADS": num_tma_threads,
         "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "N": n,
+        "K": k,
+        "NUM_GROUPS": 1,
+        "BLOCK_M": block_m,
+        "BLOCK_N": block_n,
         "BLOCK_K": block_k,
+        "SWIZZLE_D_MODE": smem_config[1],
+        "BLOCK_N_PADDING": smem_config[2],
+        "NUM_STAGES": num_stages,
+        "NUM_TMA_MULTICAST": tma_multicast_config[0],
+        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
         "NUM_SMS": num_sms,
         "SMEM_SIZE": smem_config[0],
     }
-    _, _ = jit_tuner.compile_and_tune(
-        name="m_grouped_gemm_fp8_fp8_bf16_nt",
-        keys={
-            "N": n,
-            "K": k,
-            "BLOCK_M": block_m,
-            "BLOCK_N": block_n,
-            "SWIZZLE_D_MODE": smem_config[1],
-            "BLOCK_N_PADDING": smem_config[2],
-            "NUM_GROUPS": num_groups,
-            "NUM_STAGES": num_stages,
-            "NUM_TMA_MULTICAST": tma_multicast_config[0],
-            "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE": GemmType.GroupedMasked,
-        },
-        space=(),
-        kwargs=kwargs,
-        runtime_cls=FP8GemmRuntime,
-    )
+    code = FP8GemmRuntime.generate(kwargs)
+    _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)

 def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
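The same mechanical change repeats in the two hunks below: the DeepGEMM bundled with sgl-kernel 0.1.6 apparently no longer exposes `jit_tuner`, so each compile helper folds the former `keys` into `kwargs`, renders the kernel source with `FP8GemmRuntime.generate`, and compiles it with `build`. Since the two-step tail is now identical in all three helpers, it could be factored into a shared function; a sketch under that assumption (`_generate_and_build` is a hypothetical name, not part of this commit):

```python
from typing import Any, Dict


def _generate_and_build(name: str, kwargs: Dict[str, Any]) -> None:
    # Render the CUDA source for this one kernel configuration.
    code = FP8GemmRuntime.generate(kwargs)
    # JIT-compile it; the result is discarded, since these helpers appear to
    # exist only to warm DeepGEMM's kernel cache ahead of serving traffic.
    _ = build(name, code, FP8GemmRuntime, kwargs)
```

Each `_compile_*_one` body would then reduce to assembling `kwargs` and calling `_generate_and_build(...)`.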
@@ -187,31 +183,26 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
     num_tma_threads = 128
     num_math_threads_per_group = 128
     kwargs = {
+        "GEMM_TYPE": GemmType.GroupedContiguous,
         "NUM_TMA_THREADS": num_tma_threads,
         "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "N": n,
+        "K": k,
+        "NUM_GROUPS": 1,
+        "BLOCK_M": block_m,
+        "BLOCK_N": block_n,
         "BLOCK_K": block_k,
+        "SWIZZLE_D_MODE": smem_config[1],
+        "BLOCK_N_PADDING": smem_config[2],
+        "NUM_STAGES": num_stages,
+        "NUM_TMA_MULTICAST": tma_multicast_config[0],
+        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
         "NUM_SMS": num_sms,
         "SMEM_SIZE": smem_config[0],
     }
-    _, _ = jit_tuner.compile_and_tune(
-        name="m_grouped_gemm_fp8_fp8_bf16_nt",
-        keys={
-            "N": n,
-            "K": k,
-            "BLOCK_M": block_m,
-            "BLOCK_N": block_n,
-            "SWIZZLE_D_MODE": smem_config[1],
-            "BLOCK_N_PADDING": smem_config[2],
-            "NUM_GROUPS": num_groups,
-            "NUM_STAGES": num_stages,
-            "NUM_TMA_MULTICAST": tma_multicast_config[0],
-            "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE": GemmType.GroupedContiguous,
-        },
-        space=(),
-        kwargs=kwargs,
-        runtime_cls=FP8GemmRuntime,
-    )
+    code = FP8GemmRuntime.generate(kwargs)
+    _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)

 def _compile_gemm_nt_f8f8bf16_one(
def _compile_gemm_nt_f8f8bf16_one(
......@@ -228,28 +219,23 @@ def _compile_gemm_nt_f8f8bf16_one(
"GEMM_TYPE": GemmType.Normal,
"NUM_TMA_THREADS": num_tma_threads,
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
"N": n,
"K": k,
"NUM_GROUPS": 1,
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"BLOCK_K": block_k,
"SWIZZLE_D_MODE": smem_config[1],
"BLOCK_N_PADDING": smem_config[2],
"NUM_STAGES": num_stages,
"NUM_TMA_MULTICAST": tma_multicast_config[0],
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
"NUM_SMS": num_sms,
"SMEM_SIZE": smem_config[0],
}
_, _ = jit_tuner.compile_and_tune(
name="gemm_fp8_fp8_bf16_nt",
keys={
"N": n,
"K": k,
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"SWIZZLE_D_MODE": smem_config[1],
"BLOCK_N_PADDING": smem_config[2],
"NUM_STAGES": num_stages,
"NUM_TMA_MULTICAST": tma_multicast_config[0],
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
},
space=(),
kwargs=kwargs,
runtime_cls=FP8GemmRuntime,
)
code = FP8GemmRuntime.generate(kwargs)
_ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
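Everything above sits inside the module's try/except import guard (the third hunk shows `_ENABLE_JIT_DEEPGEMM = False` as its context line), so an incompatible DeepGEMM simply leaves JIT DeepGEMM disabled rather than crashing at import time. A quick, illustrative way to confirm the new import paths resolve after upgrading:

```python
# Illustrative check, not part of the commit: verify that the import paths
# the upgraded module relies on resolve in the current environment.
try:
    from deep_gemm.jit import build  # noqa: F401
    from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType  # noqa: F401

    print("deep_gemm JIT build path available")
except ImportError as exc:
    # Mirrors the module's own guard: on failure, JIT DeepGEMM stays disabled.
    print(f"deep_gemm unavailable: {exc}")
```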