Unverified commit aa3eba8e, authored by PGFLMG and committed by GitHub

[sgl-kernel] misc: update deepgemm version for sgl-kernel (#9340)


Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: fzyzcjy <ch271828n@outlook.com>
parent 07ee0ab7
@@ -38,6 +38,8 @@ jobs:
      include:
        - python-version: "3.10"
          cuda-version: "12.4"
+       - python-version: "3.10"
+         cuda-version: "12.8"
        - python-version: "3.10"
          cuda-version: "12.9"
    name: Build Wheel (CUDA ${{ matrix.cuda-version }})
......
@@ -248,7 +248,6 @@ class EPMoE(FusedMoE):
            gateup_output,
            masked_m,
            expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
        )
        del gateup_input
        del gateup_input_fp8
@@ -304,7 +303,6 @@ class EPMoE(FusedMoE):
            down_output,
            masked_m,
            expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
        )
        del down_input
        del down_input_fp8
@@ -667,7 +665,6 @@ class DeepEPMoE(EPMoE):
            gateup_output,
            masked_m,
            expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
        )
        dispose_tensor(hidden_states_fp8[0])
@@ -708,9 +705,7 @@ class DeepEPMoE(EPMoE):
            (
                down_input_scale
                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
-                    down_input_scale
-                )
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
            ),
        )
        down_output = torch.empty(
@@ -722,7 +717,6 @@ class DeepEPMoE(EPMoE):
            down_output,
            masked_m,
            expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
        )
        return down_output
......
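A note on the MoE-layer change above: the Blackwell-only `recipe=(1, 128, 128)` argument is removed from every call site, and the scale-alignment helper is now `get_mn_major_tma_aligned_tensor`. A minimal sketch of the resulting call follows; `w13_weight_fp8` is a hypothetical stand-in, while `gateup_input_fp8`, `gateup_output`, `masked_m`, and `expected_m` are the names that appear in the diff.

# Hedged sketch of the simplified call from EPMoE; recipe selection now happens
# below this layer (inside the wrapper / DeepGEMM), not at the MoE call site.
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
    gateup_input_fp8,   # (fp8 tensor, scale tensor) tuple from the dispatcher
    w13_weight_fp8,     # hypothetical name for the (fp8, scale) expert weights
    gateup_output,      # bf16 output buffer
    masked_m,           # per-expert valid token counts
    expected_m,         # expected M used for kernel-config selection
)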
import logging
import os
from contextlib import contextmanager
-from dataclasses import dataclass
from enum import IntEnum, auto
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Dict, List, Tuple

-from tqdm.contrib.concurrent import thread_map
+import torch
+from tqdm import tqdm

from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
-    DEEPGEMM_BLACKWELL,
    ENABLE_JIT_DEEPGEMM,
)
from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var, get_int_env_var
+from sglang.srt.utils import ceil_div, get_bool_env_var, get_int_env_var

logger = logging.getLogger(__name__)

-if ENABLE_JIT_DEEPGEMM and not DEEPGEMM_BLACKWELL:
-    from deep_gemm import get_num_sms
-    from deep_gemm.jit import build
-    from deep_gemm.jit_kernels.gemm import get_best_configs
-    from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
+if ENABLE_JIT_DEEPGEMM:
+    import deep_gemm

_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
@@ -40,19 +36,7 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
# NVRTC may have performance loss with some cases.
# And NVCC JIT speed is also 9x faster in the ref commit
-_USE_NVRTC_DEFAULT = "0"
-if ENABLE_JIT_DEEPGEMM:
-    try:
-        from deep_gemm.jit.compiler import get_nvcc_compiler
-
-        get_nvcc_compiler()
-    except:
-        logger.warning(
-            "NVCC Compiler not found, use NVRTC for DeepGEMM JIT "
-            "and may have performance loss with some cases."
-        )
-        _USE_NVRTC_DEFAULT = "1"
-
-os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT)
+os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")


def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
@@ -75,7 +59,7 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
    # Default each rank will try compile all Ms to
    # load all symbols at the launch stages.
    # Avoid loading symbols at the serving stages.
-    _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE or not _IN_PRECOMPILE_STAGE
+    _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE


class DeepGemmKernelType(IntEnum):
@@ -84,185 +68,15 @@ class DeepGemmKernelType(IntEnum):
    GEMM_NT_F8F8BF16 = auto()


-@dataclass
-class DeepGemmKernelHelper:
-    name: str
-    compile_func: Callable[
-        [
-            int,
-            int,
-            int,
-            Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-        ],
-        None,
-    ]
-    configure_func: Callable[
-        [int, int, int, int, int],
-        Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-    ]


_INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()


-# TODO improve naming
-def _compile_warning_1():
-    if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
-        logger.warning(
-            "Entering DeepGEMM JIT Pre-Compile session. "
-            "It may takes a long time (typically 10-20 mins) "
-            "if you have not run `sglang.compile_deep_gemm`. "
-            "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
-            " for pre-compilation to reduce the overhead if you have not run it before. "
-            "For example: "
-            "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
-        )


-# TODO improve naming
-def _compile_warning_2():
-    logger.warning(
-        "Entering DeepGEMM JIT Single Kernel Compile session. "
-        "And it will makes inference throughput becomes flaky. "
-        "Please run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
-        " for pre-compilation to solve this issue. "
-        "For example: "
-        "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
-    )


-def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
-    n: int,
-    k: int,
-    num_groups: int,
-    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-) -> None:
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    block_k = 128
-    num_tma_threads = 128
-    num_math_threads_per_group = 128
-    kwargs = {
-        "GEMM_TYPE": GemmType.GroupedMasked,
-        "NUM_TMA_THREADS": num_tma_threads,
-        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
-        "N": n,
-        "K": k,
-        "NUM_GROUPS": num_groups,
-        "BLOCK_M": block_m,
-        "BLOCK_N": block_n,
-        "BLOCK_K": block_k,
-        "SWIZZLE_D_MODE": smem_config[1],
-        "BLOCK_N_PADDING": smem_config[2],
-        "NUM_STAGES": num_stages,
-        "NUM_TMA_MULTICAST": tma_multicast_config[0],
-        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-        "NUM_SMS": num_sms,
-        "SMEM_SIZE": smem_config[0],
-    }
-    code = FP8GemmRuntime.generate(kwargs)
-    _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)


-def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
-    n: int,
-    k: int,
-    num_groups: int,
-    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-) -> None:
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    block_k = 128
-    num_tma_threads = 128
-    num_math_threads_per_group = 128
-    kwargs = {
-        "GEMM_TYPE": GemmType.GroupedContiguous,
-        "NUM_TMA_THREADS": num_tma_threads,
-        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
-        "N": n,
-        "K": k,
-        "NUM_GROUPS": 1,
-        "BLOCK_M": block_m,
-        "BLOCK_N": block_n,
-        "BLOCK_K": block_k,
-        "SWIZZLE_D_MODE": smem_config[1],
-        "BLOCK_N_PADDING": smem_config[2],
-        "NUM_STAGES": num_stages,
-        "NUM_TMA_MULTICAST": tma_multicast_config[0],
-        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-        "NUM_SMS": num_sms,
-        "SMEM_SIZE": smem_config[0],
-    }
-    code = FP8GemmRuntime.generate(kwargs)
-    _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)


-def _compile_gemm_nt_f8f8bf16_one(
-    n: int,
-    k: int,
-    _: int,  # _ is a dummy parameter to align with other interfaces
-    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-) -> None:
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    block_k = 128
-    num_tma_threads = 128
-    num_math_threads_per_group = 128
-    kwargs = {
-        "GEMM_TYPE": GemmType.Normal,
-        "NUM_TMA_THREADS": num_tma_threads,
-        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
-        "N": n,
-        "K": k,
-        "NUM_GROUPS": 1,
-        "BLOCK_M": block_m,
-        "BLOCK_N": block_n,
-        "BLOCK_K": block_k,
-        "SWIZZLE_D_MODE": smem_config[1],
-        "BLOCK_N_PADDING": smem_config[2],
-        "NUM_STAGES": num_stages,
-        "NUM_TMA_MULTICAST": tma_multicast_config[0],
-        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-        "NUM_SMS": num_sms,
-        "SMEM_SIZE": smem_config[0],
-    }
-    code = FP8GemmRuntime.generate(kwargs)
-    _ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)


-# TODO further refactor warmup-related
-_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
-    DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
-        name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
-        compile_func=_compile_grouped_gemm_nt_f8f8bf16_masked_one,
-        configure_func=lambda m, n, k, num_groups, num_sms: get_best_configs(
-            m, n, k, num_groups, num_sms, is_grouped_masked=True
-        ),
-    ),
-    DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: DeepGemmKernelHelper(
-        name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
-        compile_func=_compile_grouped_gemm_nt_f8f8bf16_contig_one,
-        configure_func=lambda m, n, k, _, num_sms: get_best_configs(
-            m, n, k, 1, num_sms, is_grouped_contiguous=True
-        ),
-    ),
-    DeepGemmKernelType.GEMM_NT_F8F8BF16: DeepGemmKernelHelper(
-        name="gemm_fp8_fp8_bf16_nt",
-        compile_func=_compile_gemm_nt_f8f8bf16_one,
-        configure_func=lambda m, n, k, _, num_sms: get_best_configs(
-            m, n, k, 1, num_sms
-        ),
-    ),
-}


+# TODO improve code
def _maybe_compile_deep_gemm_one_type_all(
    kernel_type: DeepGemmKernelType,
    n: int,
    k: int,
    num_groups: int,
-    m_list: Optional[List[int]] = None,
) -> None:
    global _INITIALIZATION_DICT
    global _BUILTIN_M_LIST
@@ -275,61 +89,145 @@ def _maybe_compile_deep_gemm_one_type_all(
    ):
        _INITIALIZATION_DICT[query_key] = True

-        kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
-        _compile_warning_1()
+        # TODO maybe improve logs
+        if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
+            logger.warning(
+                "Entering DeepGEMM JIT Pre-Compile session. "
+                "It may takes a long time (typically 10-20 mins) "
+                "if you have not run `sglang.compile_deep_gemm`. "
+                "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
+                " for pre-compilation to reduce the overhead if you have not run it before. "
+                "For example: "
+                "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
+            )
        logger.info(
            f"Try DeepGEMM JIT Compiling for "
-            f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
+            f"<{kernel_type.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
            f"{' It only takes a little time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
        )

-        # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
-        num_sms = get_num_sms()
-        collected_configs = set()
-        for m in m_list if m_list is not None else _BUILTIN_M_LIST:
-            # Put config into set to get unique configs and reduce cases to be compiled
-            collected_configs.add(
-                kernel_helper.configure_func(m, n, k, num_groups, num_sms)
-            )
-        compile_func = lambda config: kernel_helper.compile_func(
-            n, k, num_groups, config
-        )
-        thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
+        _compile_deep_gemm_one_type_all(
+            kernel_type=kernel_type,
+            n=n,
+            k=k,
+            num_groups=num_groups,
+            m_list=_BUILTIN_M_LIST,
+        )


-@contextmanager
-def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
-    if _IN_PRECOMPILE_STAGE:
-        yield
-        return
-
-    from deep_gemm.jit.runtime import RuntimeCache
-
-    origin_func = RuntimeCache.get
-
-    def __patched_func(self, *args, **kwargs):
-        ret = origin_func(self, *args, **kwargs)
-        if ret is None:
-            kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
-            if not DEEPGEMM_BLACKWELL:
-                _compile_warning_2()
-            logger.warning(
-                f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
-            )
-        return ret
-
-    RuntimeCache.get = __patched_func
-    yield
-    RuntimeCache.get = origin_func
+# NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
+def _compile_deep_gemm_one_type_all(
+    kernel_type: DeepGemmKernelType,
+    n: int,
+    k: int,
+    num_groups: int,
+    m_list: List[int],
+) -> None:
+    if kernel_type == DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG:
+        m_alignment = deep_gemm.get_mk_alignment_for_contiguous_layout()
+        m_list = sorted(list(set(m for m in m_list if m % m_alignment == 0)))
+
+    executor = _BaseWarmupExecutor.create(
+        kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
+    )
+
+    # TODO can use multi thread
+    for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
+        executor.execute(m=m)
+
+
+class _BaseWarmupExecutor:
+    @staticmethod
+    def create(kernel_type: DeepGemmKernelType, **kwargs):
+        return {
+            DeepGemmKernelType.GEMM_NT_F8F8BF16: _NormalWarmupExecutor,
+            DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: _GroupedContWarmupExecutor,
+            DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: _GroupedMaskedWarmupExecutor,
+        }[kernel_type](**kwargs)
+
+    def execute(self, m):
+        raise NotImplementedError
+
+
+def _empty_token_fp8(size):
+    *dims, k = size
+    return (
+        torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn),
+        torch.empty(
+            (*dims, ceil_div(k, _BLOCK_SIZE)), device="cuda", dtype=torch.float32
+        ),
+    )
+
+
+def _empty_block_fp8(size):
+    *dims, n, k = size
+    return (
+        torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn),
+        torch.empty(
+            (*dims, ceil_div(n, _BLOCK_SIZE), ceil_div(k, _BLOCK_SIZE)),
+            device="cuda",
+            dtype=torch.float32,
+        ),
+    )
+
+
+_BLOCK_SIZE = 128
+
+
+class _NormalWarmupExecutor(_BaseWarmupExecutor):
+    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
+        self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))
+        self.rhs_q, self.rhs_s = _empty_block_fp8((n, k))
+        self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
+
+    def execute(self, m):
+        deep_gemm.fp8_gemm_nt(
+            (self.lhs_q[:m], self.lhs_s[:m]),
+            (self.rhs_q, self.rhs_s),
+            self.out[:m],
+        )
+
+
+class _GroupedContWarmupExecutor(_BaseWarmupExecutor):
+    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
+        self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))
+        self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k))
+        self.m_indices = torch.zeros((max_m,), device="cuda", dtype=torch.int32)
+        self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
+
+    def execute(self, m):
+        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
+            (self.lhs_q[:m], self.lhs_s[:m]),
+            (self.rhs_q, self.rhs_s),
+            self.out[:m],
+            m_indices=self.m_indices[:m],
+        )
+
+
+class _GroupedMaskedWarmupExecutor(_BaseWarmupExecutor):
+    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
+        self.lhs_q, self.lhs_s = _empty_token_fp8((num_groups, max_m, k))
+        self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k))
+        self.masked_m = torch.zeros((num_groups,), device="cuda", dtype=torch.int32)
+        self.out = torch.empty(
+            (num_groups, max_m, n), device="cuda", dtype=torch.bfloat16
+        )
+
+    def execute(self, m):
+        deep_gemm.fp8_m_grouped_gemm_nt_masked(
+            (self.lhs_q, self.lhs_s),
+            (self.rhs_q, self.rhs_s),
+            self.out,
+            masked_m=self.masked_m,
+            # DeepGEMM uses `expect_m` instead of input shape for `get_best_config`
+            expected_m=m,
+        )


@contextmanager
def deep_gemm_execution_hook(
    m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
):
-    # not supported yet
-    if not DEEPGEMM_BLACKWELL:
-        _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-
-    with _log_jit_build(m, n, k, kernel_type):
-        yield
+    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+    yield
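The refactor above replaces config enumeration plus explicit JIT builds with a warmup pass: it simply runs the real kernels once per M so DeepGEMM's own JIT cache fills up. A minimal sketch of that idea for the masked grouped kernel, with made-up shapes, assuming a CUDA device and the new DeepGEMM API used above:

import torch
import deep_gemm  # the DeepGEMM build shipped with sgl-kernel

# Made-up problem sizes; the real code sweeps _BUILTIN_M_LIST.
n, k, num_groups, max_m = 4096, 7168, 8, 1024
lhs_q = torch.empty((num_groups, max_m, k), device="cuda", dtype=torch.float8_e4m3fn)
lhs_s = torch.empty((num_groups, max_m, k // 128), device="cuda", dtype=torch.float32)
rhs_q = torch.empty((num_groups, n, k), device="cuda", dtype=torch.float8_e4m3fn)
rhs_s = torch.empty((num_groups, n // 128, k // 128), device="cuda", dtype=torch.float32)
out = torch.empty((num_groups, max_m, n), device="cuda", dtype=torch.bfloat16)
masked_m = torch.zeros((num_groups,), device="cuda", dtype=torch.int32)

for m in (128, 256, 1024):
    # Same call the _GroupedMaskedWarmupExecutor issues; each distinct shape
    # triggers (and caches) one JIT compilation.
    deep_gemm.fp8_m_grouped_gemm_nt_masked(
        (lhs_q, lhs_s), (rhs_q, rhs_s), out, masked_m=masked_m, expected_m=m
    )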
@@ -24,14 +24,12 @@ def _compute_enable_deep_gemm():
    return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")


-ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
+def _is_blackwell_arch() -> bool:
+    major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
+    return major == 10

-try:
-    from deep_gemm import fp8_gemm_nt
-
-    # They have not given a name to this breaking change
-    DEEPGEMM_BLACKWELL = True
-except ImportError:
-    DEEPGEMM_BLACKWELL = False
+
+ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
+DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and _is_blackwell_arch()

DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
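For reference, a quick sketch of what the new detection evaluates (it assumes a visible CUDA device): Blackwell data-center GPUs (SM100-class, e.g. B200) report compute capability major version 10, so `DEEPGEMM_BLACKWELL` is true exactly when JIT DeepGEMM is enabled on such a device, instead of being inferred from which DeepGEMM symbols happen to import.

import torch

# Prints (10, 0) on an SM100 GPU; on Hopper (H100/H200) it prints (9, 0),
# so DEEPGEMM_BLACKWELL stays False there.
print(torch.cuda.get_device_capability(torch.cuda.current_device()))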
@@ -16,33 +16,16 @@ logger = logging.getLogger(__name__)

if ENABLE_JIT_DEEPGEMM:
    import deep_gemm
+    from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor

-    if DEEPGEMM_BLACKWELL:
-        from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw
-        from deep_gemm import (
-            fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
-        )
-        from deep_gemm import (
-            m_grouped_fp8_gemm_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
-        )
-    else:
-        from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
-        from deep_gemm import get_col_major_tma_aligned_tensor
-        from deep_gemm import (
-            m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
-        )
-        from deep_gemm import (
-            m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
-        )


+# TODO maybe rename these functions
def grouped_gemm_nt_f8f8bf16_masked(
    lhs: Tuple[torch.Tensor, torch.Tensor],
    rhs: Tuple[torch.Tensor, torch.Tensor],
    out: torch.Tensor,
    masked_m: torch.Tensor,
    expected_m: int,
-    recipe=None,
):
    num_groups, _, k = lhs[0].shape
    _, n, _ = rhs[0].shape
@@ -51,13 +34,12 @@ def grouped_gemm_nt_f8f8bf16_masked(
    with compile_utils.deep_gemm_execution_hook(
        expected_m, n, k, num_groups, kernel_type
    ):
-        _grouped_gemm_nt_f8f8bf16_masked_raw(
+        deep_gemm.fp8_m_grouped_gemm_nt_masked(
            lhs,
            rhs,
            out,
            masked_m,
            expected_m,
-            **({"recipe": recipe} if DEEPGEMM_BLACKWELL else {})
        )
@@ -72,7 +54,7 @@ def grouped_gemm_nt_f8f8bf16_contig(
    kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG

    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
-        _grouped_gemm_nt_f8f8bf16_contig_raw(lhs, rhs, out, m_indices)
+        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)


def gemm_nt_f8f8bf16(
@@ -86,7 +68,7 @@ def gemm_nt_f8f8bf16(
    kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16

    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
-        _gemm_nt_f8f8bf16_raw(
+        deep_gemm.fp8_gemm_nt(
            lhs,
            rhs,
            out,
......
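Taken together, the wrappers keep their SGLang-facing signatures while dispatching directly to the renamed DeepGEMM entry points. A hedged usage sketch of the plain GEMM wrapper follows, with made-up shapes; the import path follows the `deep_gemm_wrapper` package referenced above, and the scale-tensor layouts are simplified (per-128 token groups on the left, 128x128 blocks on the right):

import torch
from sglang.srt.layers.quantization import deep_gemm_wrapper

m, n, k = 256, 4096, 7168
lhs = (
    torch.empty((m, k), device="cuda", dtype=torch.float8_e4m3fn),
    torch.empty((m, k // 128), device="cuda", dtype=torch.float32),
)
rhs = (
    torch.empty((n, k), device="cuda", dtype=torch.float8_e4m3fn),
    torch.empty((n // 128, k // 128), device="cuda", dtype=torch.float32),
)
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)

# Internally this runs deep_gemm.fp8_gemm_nt under the compile/warmup hook.
deep_gemm_wrapper.gemm_nt_f8f8bf16(lhs, rhs, out)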
@@ -298,7 +298,7 @@ def _per_token_group_quant_8bit_raw(
    )

    if scale_ue8m0:
-        from deep_gemm.utils.layout import transform_sf_into_required_layout
+        from deep_gemm import transform_sf_into_required_layout

        assert group_size == 128
        x_s = transform_sf_into_required_layout(
@@ -338,7 +338,7 @@ def _per_token_group_quant_8bit_fuse_silu_and_mul(
    #     scale_ue8m0=scale_ue8m0,
    # )
-    from deep_gemm.utils.layout import transform_sf_into_required_layout
+    from deep_gemm import transform_sf_into_required_layout

    from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
......
@@ -459,7 +459,7 @@ def _requant_weight_ue8m0(
        import deep_gemm.utils.layout

        sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
-        sf = deep_gemm.utils.layout.get_col_major_tma_aligned_packed_tensor(sf)
+        sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
        return sf

    out_s = _transform_scale(out_s, mn=out_w.shape[-2])
......
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from typing import Optional
+
import torch
@@ -24,7 +26,7 @@ class MXFP4QuantizeUtil:
    E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])

    @classmethod
-    def quantize(cls, input: torch.Tensor, block_size: int | None) -> tuple:
+    def quantize(cls, input: torch.Tensor, block_size: Optional[int]) -> tuple:
        """Converting a tensor to a quantized format based on MXFP4 quantization. Only E4M3 is supported.
        Args:
            input (torch.Tensor): The input tensor to be quantized.
......
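The annotation change above swaps PEP 604 syntax for `typing.Optional`, presumably so the module still imports on interpreters older than Python 3.10, where `int | None` in an eagerly evaluated annotation raises a TypeError. A tiny illustration:

from typing import Optional

# On Python 3.9, `def f(x: int | None): ...` fails at definition time with
# "TypeError: unsupported operand type(s) for |" unless
# `from __future__ import annotations` is in effect. The Optional spelling
# works on every supported version:
def f(x: Optional[int]) -> None:
    pass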
@@ -50,25 +50,17 @@ FetchContent_Declare(
)
FetchContent_Populate(repo-cutlass)

-# DeepGEMM
-if("${CUDA_VERSION}" VERSION_EQUAL "12.8")
-  set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM")
-  set(DeepGEMM_TAG "blackwell")
-elseif("${CUDA_VERSION}" VERSION_EQUAL "12.9")
-  set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM")
-  set(DeepGEMM_TAG "blackwell")
-elseif("${CUDA_VERSION}" VERSION_EQUAL "13.0")
-  set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM")
-  set(DeepGEMM_TAG "blackwell")
-else()
-  set(DeepGEMM_REPO "https://github.com/deepseek-ai/DeepGEMM")
-  set(DeepGEMM_TAG "391755ada0ffefa9a6a52b6f14dcaf22d1a463e0")
-endif()
+FetchContent_Declare(
+  repo-fmt
+  GIT_REPOSITORY https://github.com/fmtlib/fmt
+  GIT_TAG 553ec11ec06fbe0beebfbb45f9dc3c9eabd83d28
+  GIT_SHALLOW OFF
+)

FetchContent_Declare(
  repo-deepgemm
-  GIT_REPOSITORY ${DeepGEMM_REPO}
-  GIT_TAG ${DeepGEMM_TAG}
+  GIT_REPOSITORY https://github.com/sgl-project/DeepGEMM
+  GIT_TAG sgl
  GIT_SHALLOW OFF
)
FetchContent_Populate(repo-deepgemm)
@@ -86,7 +78,7 @@ FetchContent_Populate(repo-triton)
FetchContent_Declare(
  repo-flashinfer
  GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
-  GIT_TAG 018b551825c8e5579206e6eb9d3229fa679202b3
+  GIT_TAG 9220fb3443b5a5d274f00ca5552f798e225239b7
  GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flashinfer)
@@ -182,28 +174,11 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
  list(APPEND SGL_KERNEL_CUDA_FLAGS
    "-gencode=arch=compute_100,code=sm_100"
    "-gencode=arch=compute_100a,code=sm_100a"
-    "-gencode=arch=compute_103,code=sm_103"
-    "-gencode=arch=compute_103a,code=sm_103a"
+    "-gencode=arch=compute_101,code=sm_101"
+    "-gencode=arch=compute_101a,code=sm_101a"
    "-gencode=arch=compute_120,code=sm_120"
    "-gencode=arch=compute_120a,code=sm_120a"
  )
-
-  # refer sm_121, sm_110 and sm_101 description https://github.com/pytorch/pytorch/pull/156176
-  if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0")
-    list(APPEND SGL_KERNEL_CUDA_FLAGS
-      "-gencode=arch=compute_110,code=sm_110"
-      "-gencode=arch=compute_110a,code=sm_110a"
-      "-gencode=arch=compute_121,code=sm_121"
-      "-gencode=arch=compute_121a,code=sm_121a"
-      "--compress-mode=size"
-    )
-  else()
-    list(APPEND SGL_KERNEL_CUDA_FLAGS
-      "-gencode=arch=compute_101,code=sm_101"
-      "-gencode=arch=compute_101a,code=sm_101a"
-    )
-  endif()
else()
  list(APPEND SGL_KERNEL_CUDA_FLAGS
    "-use_fast_math"
@@ -286,6 +261,12 @@ set(SOURCES
  "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
  "csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
  "csrc/moe/marlin_moe_wna16/ops.cu"
+  "csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu"
+  "csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu"
+  "csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu"
+  "csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu"
+  "csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu"
+  "csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu"
  "csrc/moe/moe_align_kernel.cu"
  "csrc/moe/moe_fused_gate.cu"
  "csrc/moe/moe_topk_softmax_kernels.cu"
@@ -321,8 +302,6 @@ target_include_directories(common_ops PRIVATE
  ${repo-cutlass_SOURCE_DIR}/examples/common
  ${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
)

-set_source_files_properties("csrc/gemm/per_token_group_quant_8bit" PROPERTIES COMPILE_OPTIONS "--use_fast_math")

find_package(Python3 COMPONENTS Interpreter REQUIRED)
execute_process(
@@ -464,13 +443,38 @@ install(TARGETS spatial_ops LIBRARY DESTINATION sgl_kernel)
set(DEEPGEMM_SOURCES
  "${repo-deepgemm_SOURCE_DIR}/csrc/python_api.cpp"
)

-# JIT Logic
-# DeepGEMM
-install(DIRECTORY "${repo-deepgemm_SOURCE_DIR}/deep_gemm/"
-  DESTINATION "deep_gemm"
-  PATTERN ".git*" EXCLUDE
-  PATTERN "__pycache__" EXCLUDE)
+Python_add_library(deep_gemm_cpp MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${DEEPGEMM_SOURCES})
+
+# Link against necessary libraries, including nvrtc for JIT compilation.
+target_link_libraries(deep_gemm_cpp PRIVATE ${TORCH_LIBRARIES} c10 cuda nvrtc mscclpp_static)
+
+# Add include directories needed by DeepGEMM.
+target_include_directories(deep_gemm_cpp PRIVATE
+  ${repo-deepgemm_SOURCE_DIR}/deep_gemm/include
+  ${repo-cutlass_SOURCE_DIR}/include
+  ${repo-fmt_SOURCE_DIR}/include
+)
+
+# Apply the same compile options as common_ops.
+target_compile_options(deep_gemm_cpp PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
+
+# Create an empty __init__.py to make `deepgemm` a Python package.
+file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/deepgemm_pkg_init.py "")
+install(
+  FILES ${CMAKE_CURRENT_BINARY_DIR}/deepgemm_pkg_init.py
+  DESTINATION deep_gemm
+  RENAME __init__.py
+)
+
+# Install the compiled DeepGEMM API library.
+install(TARGETS deep_gemm_cpp LIBRARY DESTINATION deep_gemm)
+
+# Install the source files required by DeepGEMM for runtime JIT compilation.
+install(
+  DIRECTORY ${repo-deepgemm_SOURCE_DIR}/deep_gemm/
+  DESTINATION deep_gemm
+)

install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cute/"
  DESTINATION "deep_gemm/include/cute")
......
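In rough terms, the new CMake section builds DeepGEMM's C++ API (`deep_gemm_cpp`) as part of the wheel and installs it next to DeepGEMM's Python sources, so the `deep_gemm` package imports like any other installed module. A small sketch of what consumers rely on; the listed attributes are the ones referenced elsewhere in this change:

import deep_gemm

# Resolves from site-packages rather than an in-tree checkout.
print(deep_gemm.__file__)
# Entry points used by the SGLang wrapper above.
print(deep_gemm.fp8_gemm_nt)
print(deep_gemm.get_mk_alignment_for_contiguous_layout)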
@@ -9,7 +9,6 @@ import jinja2
FILE_HEAD = """
// auto generated by generate.py
// clang-format off
-#pragma once

#include "kernel.h"
#include "marlin_template.h"
@@ -34,17 +33,6 @@ TEMPLATE = (
    "( MARLIN_KERNEL_PARAMS );"
)

-KERNEL_FILE_TEMPLATE = (
-    "// auto generated by generate.py\n"
-    "// clang-format off\n"
-    "#pragma once\n\n"
-    "{% for kernel_file in kernel_files %}"
-    '#include "{{ kernel_file }}"\n'
-    "{% endfor %}"
-)
-KERNEL_FILE_NAME = "kernel_marlin.cuh"
-
# int8 with zero point case (sglang::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"]
@@ -60,12 +48,11 @@ DTYPES = ["fp16", "bf16"]

def remove_old_kernels():
-    for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cuh"):
+    for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
        subprocess.call(["rm", "-f", filename])


def generate_new_kernels():
-    kernel_files = set()
    for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
        has_zp = "B" not in scalar_type
        all_template_str_list = []
@@ -108,20 +95,10 @@ def generate_new_kernels():
        file_content = FILE_HEAD + "\n\n"
        file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"

-        filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cuh"
+        filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cu"

        with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
            f.write(file_content)
-        kernel_files.add(filename)
-
-    kernel_files = list(kernel_files)
-    kernel_files.sort()
-    file_content = jinja2.Template(KERNEL_FILE_TEMPLATE).render(
-        kernel_files=kernel_files
-    )
-    with open(os.path.join(os.path.dirname(__file__), KERNEL_FILE_NAME), "w") as f:
-        f.write(file_content)


if __name__ == "__main__":
......
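With the aggregator header gone, each (dtype, scalar-type) pair now yields a standalone .cu translation unit that CMakeLists.txt lists explicitly (the six kernel_*.cu entries added to SOURCES above). A small sketch that reproduces the expected filenames from the generator's own constants:

import itertools

SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"]
DTYPES = ["fp16", "bf16"]

# Prints the six filenames added to SOURCES in CMakeLists.txt,
# e.g. kernel_fp16_ku4.cu ... kernel_bf16_ku8b128.cu.
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
    print(f"kernel_{dtype}_{scalar_type[8:].lower()}.cu")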
-#pragma once
-
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
......
// auto generated by generate.py
// clang-format off
-#pragma once
#include "kernel.h"
#include "marlin_template.h"
......
// auto generated by generate.py
// clang-format off
-#pragma once
#include "kernel.h"
#include "marlin_template.h"
......
// auto generated by generate.py
// clang-format off
-#pragma once
#include "kernel.h"
#include "marlin_template.h"
......
// auto generated by generate.py
// clang-format off
-#pragma once
#include "kernel.h"
#include "marlin_template.h"
......
// auto generated by generate.py
// clang-format off
-#pragma once
#include "kernel.h"
#include "marlin_template.h"
......
// auto generated by generate.py
// clang-format off
-#pragma once
#include "kernel.h"
#include "marlin_template.h"
......
-// auto generated by generate.py
-// clang-format off
-#pragma once
-
-#include "kernel_bf16_ku4.cuh"
-#include "kernel_bf16_ku4b8.cuh"
-#include "kernel_bf16_ku8b128.cuh"
-#include "kernel_fp16_ku4.cuh"
-#include "kernel_fp16_ku4b8.cuh"
-#include "kernel_fp16_ku8b128.cuh"
@@ -18,8 +18,6 @@
/*
 * Adapted from https://github.com/IST-DASLab/marlin
 */
-#pragma once
-
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
......
@@ -24,7 +24,6 @@
#endif

#include "kernel.h"
-#include "kernel_marlin.cuh"

#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
  static_assert( \
......