"git@developer.sourcefind.cn:change/sglang.git" did not exist on "fa3c9e0668f16b9c5946dd621202f47324e71786"
Unverified commit b4c41f72, authored by fzyzcjy and committed by GitHub

Refactor DeepGEMM integration (#7150)

parent 8b8f2e74
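The diff below removes the per-file `try: import deep_gemm` blocks and routes every call through a single `deep_gemm_wrapper` facade. As a rough orientation before reading the hunks, here is a minimal sketch of the call-site pattern this refactor converges on (assuming an sglang installation that includes the new wrapper module; `run_block_fp8_gemm` is a hypothetical helper, not code from this commit):

```python
import torch

from sglang.srt.layers.quantization import deep_gemm_wrapper


def run_block_fp8_gemm(lhs_fp8, lhs_scale, rhs_fp8, rhs_scale, out: torch.Tensor):
    # Callers consult one capability flag instead of local try/except import guards.
    if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
        # The wrapper entry point handles JIT warmup and compile logging internally.
        deep_gemm_wrapper.gemm_nt_f8f8bf16((lhs_fp8, lhs_scale), (rhs_fp8, rhs_scale), out)
    else:
        raise NotImplementedError("fall back to the Triton kernel path here")
```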
@@ -4,6 +4,7 @@ from typing import List, Optional
 import torch
 import triton

+from sglang.math_utils import ceil_div
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import dispose_tensor, is_cuda
@@ -15,11 +16,6 @@ if _is_cuda:
         sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
     )

-try:
-    from deep_gemm import ceil_div
-except ImportError:
-    logger.error(f"Failed to import ceil_div from deep_gemm.")
-
 import triton.language as tl
...
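Note on the first hunk: `ceil_div` now comes from `sglang.math_utils` instead of being imported (with a fallback error log) from `deep_gemm`. For reference, a pure-Python equivalent of ceiling division for positive integers; this is an illustration, not the actual sglang implementation:

```python
def ceil_div(a: int, b: int) -> int:
    # Ceiling division for positive integers: ceil(a / b) without floats.
    return (a + b - 1) // b


assert ceil_div(7, 4) == 2 and ceil_div(8, 4) == 2
```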
 import logging
 from typing import Callable, List, Optional, Tuple

-import einops
 import torch
-from sgl_kernel import silu_and_mul
 from torch.nn import Module

-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
-from sglang.srt.managers.expert_location import get_global_expert_location_metadata
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-
-try:
-    from deep_gemm import (
-        get_col_major_tma_aligned_tensor,
-        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
-        m_grouped_gemm_fp8_fp8_bf16_nt_masked,
-    )
-    from sgl_kernel import silu_and_mul
-
-    from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
-    )
-
-    use_deep_gemm = True
-except ImportError:
-    use_deep_gemm = False
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -45,6 +26,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -52,10 +34,20 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import (
     scaled_fp8_quant,
+    sglang_per_token_group_quant_fp8,
     sglang_per_token_quant_fp8,
 )
+from sglang.srt.managers.expert_location import get_global_expert_location_metadata
+from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
+from sglang.srt.utils import (
+    DeepEPMode,
+    dispose_tensor,
+    get_bool_env_var,
+    is_hip,
+    set_weight_attrs,
+)

 _is_hip = is_hip()
@@ -680,7 +672,6 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
@@ -920,7 +911,9 @@ class DeepEPMoE(EPMoE):
         )
         self.deepep_mode = deepep_mode
         if self.deepep_mode.enable_low_latency():
-            assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
+            assert (
+                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+            ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
         self.w13_weight_fp8 = (
             self.w13_weight,
             (
@@ -948,7 +941,7 @@ class DeepEPMoE(EPMoE):
     ):
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
-            if _ENABLE_JIT_DEEPGEMM:
+            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
                 return self.forward_deepgemm_contiguous(
                     hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
                 )
@@ -1145,7 +1138,7 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
        input_tensor[1] = tma_align_input_scale(input_tensor[1])
-        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
         del input_tensor
@@ -1169,7 +1162,7 @@ class DeepEPMoE(EPMoE):
         )
         del down_input
         down_input_scale = tma_align_input_scale(down_input_scale)
-        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             (down_input_fp8, down_input_scale),
             self.w2_weight_fp8,
             down_output,
@@ -1202,8 +1195,13 @@ class DeepEPMoE(EPMoE):
         gateup_output = torch.empty(
             (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
         )
-        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            hidden_states_fp8,
+            self.w13_weight_fp8,
+            gateup_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_V202506 else None,
         )
         dispose_tensor(hidden_states_fp8[0])
@@ -1240,13 +1238,18 @@ class DeepEPMoE(EPMoE):
         n = self.w2_weight.size(1)
         down_input_fp8 = (
             down_input,
-            get_col_major_tma_aligned_tensor(down_input_scale),
+            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
         )
         down_output = torch.empty(
             (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
         )
-        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            down_input_fp8,
+            self.w2_weight_fp8,
+            down_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_V202506 else None,
         )
         return down_output
...
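The two `grouped_gemm_nt_f8f8bf16_masked` calls above follow a fixed shape contract, and the wrapper derives its kernel key from those shapes. A sketch of the expected shapes with illustrative sizes (the real call additionally needs FP8 data and properly aligned scale tensors on a Hopper GPU):

```python
import torch

num_groups, m, k, n = 4, 64, 7168, 4096  # illustrative sizes, not taken from the diff

lhs = torch.empty(num_groups, m, k)                        # hidden_states_fp8[0] (FP8 in practice)
rhs = torch.empty(num_groups, n, k)                        # w13_weight_fp8[0]; note the "nt" layout
out = torch.empty(num_groups, m, n, dtype=torch.bfloat16)  # gateup_output
masked_m = torch.full((num_groups,), m)                    # valid rows per expert group
expected_m = m                                             # plain int hint used for kernel selection
```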
 import logging
 from dataclasses import dataclass

-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.expert_distribution import (
     get_global_expert_distribution_recorder,
 )
@@ -236,14 +236,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_weights: torch.Tensor,
     ):
         topk_idx = topk_idx.to(torch.int64)
-        if _ENABLE_JIT_DEEPGEMM:
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             # TODO hard code 128 block quant,use fp8 communication
             hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_idx, topk_weights, previous_event

     def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
-        if _ENABLE_JIT_DEEPGEMM:
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             (
                 hidden_states,
                 topk_idx,
@@ -345,7 +345,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             previous_event=previous_event,
             async_finish=self.async_finish,
             allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
-            expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
+            expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
             config=DeepEPConfig.get_instance().normal_dispatch_config,
         )
@@ -409,7 +409,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        if _ENABLE_JIT_DEEPGEMM:
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             output = hidden_states
         else:
             if hidden_states.shape[0] > 0:
...
@@ -5,33 +5,24 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import Callable, Dict, List, Optional, Tuple

-import torch
 from tqdm.contrib.concurrent import thread_map

+from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
+    DEEPGEMM_V202506,
+    ENABLE_JIT_DEEPGEMM,
+)
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda
+from sglang.srt.utils import get_bool_env_var, get_int_env_var

 logger = logging.getLogger(__name__)

-_ENABLE_JIT_DEEPGEMM = False
 try:
-    import deep_gemm
     from deep_gemm import get_num_sms
     from deep_gemm.jit import build
-    from deep_gemm.jit.compiler import get_nvcc_compiler
     from deep_gemm.jit_kernels.gemm import get_best_configs
     from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
-
-    sm_version = get_device_sm()
-    if sm_version == 90:
-        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
-            _ENABLE_JIT_DEEPGEMM = True
 except ImportError:
-    logger.warning("Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.")
-
-
-def get_enable_jit_deepgemm():
-    return _ENABLE_JIT_DEEPGEMM
-
+    pass

 _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
@@ -52,8 +43,10 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
 # NVRTC may have performance loss with some cases.
 # And NVCC JIT speed is also 9x faster in the ref commit
 _USE_NVRTC_DEFAULT = "0"
-if _ENABLE_JIT_DEEPGEMM:
+if ENABLE_JIT_DEEPGEMM:
     try:
+        from deep_gemm.jit.compiler import get_nvcc_compiler
+
         get_nvcc_compiler()
     except:
         logger.warning(
@@ -114,6 +107,7 @@ class DeepGemmKernelHelper:
 _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()

+# TODO improve naming
 def _compile_warning_1():
     if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
         logger.warning(
@@ -127,6 +121,7 @@ def _compile_warning_1():
         )

+# TODO improve naming
 def _compile_warning_2():
     logger.warning(
         "Entering DeepGEMM JIT Single Kernel Compile session. "
@@ -238,6 +233,7 @@ def _compile_gemm_nt_f8f8bf16_one(
     _ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)

+# TODO further refactor warmup-related
 _KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
     DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
         name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
@@ -270,7 +266,6 @@ def _maybe_compile_deep_gemm_one_type_all(
     num_groups: int,
     m_list: Optional[List[int]] = None,
 ) -> None:
-
     global _INITIALIZATION_DICT
     global _BUILTIN_M_LIST
@@ -304,56 +299,6 @@ def _maybe_compile_deep_gemm_one_type_all(
     thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
-def grouped_gemm_nt_f8f8bf16_masked(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-    masked_m: torch.Tensor,
-    expected_m: int,
-):
-    num_groups, _, k = lhs[0].shape
-    _, n, _ = rhs[0].shape
-    kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-    with _log_jit_build(expected_m, n, k, kernel_type):
-        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            lhs, rhs, out, masked_m, expected_m
-        )
-
-
-def grouped_gemm_nt_f8f8bf16_contig(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-    m_indices: torch.Tensor,
-):
-    m, k = lhs[0].shape
-    num_groups, n, _ = rhs[0].shape
-    kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-    with _log_jit_build(m, n, k, kernel_type):
-        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
-
-
-def gemm_nt_f8f8bf16(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-):
-    m, k = lhs[0].shape
-    n, _ = rhs[0].shape
-    kernel_type = DeepGemmKernelType.GEMM_NT_F8F8BF16
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, 1)
-    with _log_jit_build(m, n, k, kernel_type):
-        deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
-
-
 @contextmanager
 def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
     if _IN_PRECOMPILE_STAGE:
@@ -380,13 +325,14 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):

 @contextmanager
-def configure_deep_gemm_num_sms(num_sms):
-    if num_sms is None:
+def deep_gemm_execution_hook(
+    m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
+):
+    # not supported yet
+    if DEEPGEMM_V202506:
+        yield
+        return
+
+    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+    with _log_jit_build(m, n, k, kernel_type):
         yield
-    else:
-        original_num_sms = deep_gemm.get_num_sms()
-        deep_gemm.set_num_sms(num_sms)
-        try:
-            yield
-        finally:
-            deep_gemm.set_num_sms(original_num_sms)

+import logging
+
+from sglang.srt.utils import get_bool_env_var, get_device_sm
+
+logger = logging.getLogger(__name__)
+
+
+def _compute_enable_deep_gemm():
+    try:
+        import deep_gemm
+    except ImportError:
+        logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.")
+        return False
+
+    sm_version = get_device_sm()
+    if sm_version < 90:
+        return False
+
+    return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
+
+
+ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
+
+DEEPGEMM_V202506 = False
+DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_V202506
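The new configurer module above computes the feature flags once at import time (deep_gemm importable, SM version at least 90, and `SGL_ENABLE_JIT_DEEPGEMM` not set to false). A hedged usage sketch, assuming an sglang installation that includes this module:

```python
from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
    DEEPGEMM_SCALE_UE8M0,
    DEEPGEMM_V202506,
    ENABLE_JIT_DEEPGEMM,
)

# ENABLE_JIT_DEEPGEMM gates every DeepGEMM code path touched by this refactor;
# the two version flags are placeholders for the 2025-06 DeepGEMM API (off here).
print(
    f"JIT DeepGEMM: {ENABLE_JIT_DEEPGEMM}, "
    f"v2025.06 API: {DEEPGEMM_V202506}, UE8M0 scales: {DEEPGEMM_SCALE_UE8M0}"
)
```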

+import logging
+from contextlib import contextmanager
+from typing import Tuple
+
+import torch
+
+from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils
+from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
+    DEEPGEMM_SCALE_UE8M0,
+    DEEPGEMM_V202506,
+    ENABLE_JIT_DEEPGEMM,
+)
+from sglang.srt.server_args import ServerArgs
+
+logger = logging.getLogger(__name__)
+
+if ENABLE_JIT_DEEPGEMM:
+    import deep_gemm
+    from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
+    from deep_gemm import get_col_major_tma_aligned_tensor
+    from deep_gemm import (
+        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
+    )
+    from deep_gemm import (
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
+    )
+
+
+def grouped_gemm_nt_f8f8bf16_masked(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m: int,
+    recipe=None,
+):
+    num_groups, _, k = lhs[0].shape
+    _, n, _ = rhs[0].shape
+    kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
+
+    with compile_utils.deep_gemm_execution_hook(
+        expected_m, n, k, num_groups, kernel_type
+    ):
+        _grouped_gemm_nt_f8f8bf16_masked_raw(
+            lhs, rhs, out, masked_m, expected_m, recipe=recipe
+        )
+
+
+def grouped_gemm_nt_f8f8bf16_contig(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+    m_indices: torch.Tensor,
+):
+    m, k = lhs[0].shape
+    num_groups, n, _ = rhs[0].shape
+    kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
+
+    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
+        _grouped_gemm_nt_f8f8bf16_contig_raw(lhs, rhs, out, m_indices)
+
+
+def gemm_nt_f8f8bf16(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+):
+    m, k = lhs[0].shape
+    n, _ = rhs[0].shape
+    num_groups = 1
+    kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
+
+    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
+        _gemm_nt_f8f8bf16_raw(
+            lhs,
+            rhs,
+            out,
+        )
+
+
+def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
+    compile_utils.update_deep_gemm_config(gpu_id, server_args)
+
+
+@contextmanager
+def configure_deep_gemm_num_sms(num_sms):
+    if num_sms is None:
+        yield
+    else:
+        original_num_sms = deep_gemm.get_num_sms()
+        deep_gemm.set_num_sms(num_sms)
+        try:
+            yield
+        finally:
+            deep_gemm.set_num_sms(original_num_sms)
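`configure_deep_gemm_num_sms` keeps its old behavior but now lives in the wrapper entry point; two-batch overlap (the last hunk of this commit) is its only caller shown here. A minimal usage sketch, assuming the JIT path is enabled (`run_with_sm_budget` is a hypothetical helper):

```python
from sglang.srt.layers.quantization import deep_gemm_wrapper


def run_with_sm_budget(fn, num_sms=None):
    # num_sms=None leaves DeepGEMM's global SM count untouched; otherwise the
    # previous value is restored even if fn() raises.
    with deep_gemm_wrapper.configure_deep_gemm_num_sms(num_sms):
        return fn()
```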
@@ -23,7 +23,8 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.math_utils import align
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_device_core_count,
@@ -44,10 +45,6 @@ if _is_cuda:
         sgl_per_token_quant_fp8,
     )

-    from sglang.srt.layers.quantization.deep_gemm import (
-        gemm_nt_f8f8bf16 as deep_gemm_gemm_nt_f8f8bf16,
-    )
-
 logger = logging.getLogger(__name__)
@@ -67,7 +64,6 @@ else:
 fp8_max = torch.finfo(fp8_dtype).max
 fp8_min = -fp8_max

-
 if supports_custom_op():

     def deep_gemm_fp8_fp8_bf16_nt(
@@ -77,7 +73,7 @@ if supports_custom_op():
         Bs: torch.Tensor,
         C: torch.Tensor,
     ) -> None:
-        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
+        deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)

     def deep_gemm_fp8_fp8_bf16_nt_fake(
         A: torch.Tensor,
@@ -797,12 +793,12 @@ def w8a8_block_fp8_matmul_deepgemm(
     M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)

     # Deepgemm only supports output tensor type as bfloat16
-    assert C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM
+    assert C.dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM

     if supports_custom_op():
         torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
     else:
-        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
+        deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)
     return C
@@ -896,7 +892,7 @@ def w8a8_block_fp8_matmul(
     block_size: List[int],
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    if output_dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
+    if output_dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
         return w8a8_block_fp8_matmul_deepgemm(
             A, B, As, Bs, block_size, output_dtype=output_dtype
         )
...
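The last two hunks above encode a simple routing rule: the DeepGEMM path is taken only when the requested output dtype is bfloat16 and the JIT flag is on. A hedged sketch of that rule in isolation (`pick_block_fp8_backend` is illustrative, not part of the diff):

```python
import torch

from sglang.srt.layers.quantization import deep_gemm_wrapper


def pick_block_fp8_backend(output_dtype: torch.dtype) -> str:
    # Mirrors w8a8_block_fp8_matmul above: DeepGEMM only emits bfloat16 outputs.
    if output_dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
        return "deepgemm"
    return "triton"
```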
-import os
-from curses import flash
 from typing import Callable, List, Optional, Tuple

+import einops
 import torch

 from sglang.math_utils import align
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.utils import is_sm100_supported
@@ -15,7 +15,6 @@ try:
 except ImportError:
     VLLM_AVAILABLE = False

-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
     fp8_dtype,
     fp8_max,
@@ -138,7 +137,7 @@ def dispatch_w8a8_block_fp8_linear() -> Callable:
         return cutlass_w8a8_block_fp8_linear_with_fallback
     elif _use_aiter:
         return aiter_w8a8_block_fp8_linear
-    elif _ENABLE_JIT_DEEPGEMM:
+    elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
         return deepgemm_w8a8_block_fp8_linear_with_fallback
     else:
         return triton_w8a8_block_fp8_linear
...
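`dispatch_w8a8_block_fp8_linear` above now reads the flag from the wrapper but keeps its priority order. A sketch of that order with the conditions reduced to booleans (names are illustrative, not from the diff):

```python
from sglang.srt.layers.quantization import deep_gemm_wrapper


def choose_block_fp8_linear_backend(cutlass_supported: bool, use_aiter: bool) -> str:
    # Priority as in the hunk above: CUTLASS -> AITER (ROCm) -> DeepGEMM JIT -> Triton.
    if cutlass_supported:
        return "cutlass"
    if use_aiter:
        return "aiter"
    if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
        return "deepgemm"
    return "triton"
```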
@@ -26,6 +26,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist

+from sglang.srt import debug_utils
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
@@ -45,10 +46,9 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
-from sglang.srt.layers.quantization.deep_gemm import (
-    _ENABLE_JIT_DEEPGEMM,
-    update_deep_gemm_config,
+from sglang.srt.layers.quantization import (
+    deep_gemm_wrapper,
+    monkey_patch_isinstance_for_vllm_base_layer,
 )
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
@@ -205,8 +205,8 @@ class ModelRunner:
         min_per_gpu_memory = self.init_torch_distributed()

         # Update deep gemm configure
-        if _ENABLE_JIT_DEEPGEMM:
-            update_deep_gemm_config(gpu_id, server_args)
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+            deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)

         # If it is a draft model, tp_group can be different
         self.initialize(min_per_gpu_memory)
...
@@ -54,8 +54,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     per_tensor_quant_mla_fp8,
@@ -110,10 +110,6 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2

-    from sglang.srt.layers.quantization.deep_gemm import (
-        grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
-    )
-
 else:
     from vllm._custom_ops import awq_dequantize
@@ -981,7 +977,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope.new_empty(
                 (self.num_local_heads, aligned_m, self.kv_lora_rank)
             )
-            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                 (q_nope_val, q_nope_scale),
                 (self.w_kc, self.w_scale_k),
                 q_nope_out,
@@ -1851,7 +1847,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 and weight_block_size[1] == 128
                 and model_dtype == torch.bfloat16
             ):
-                if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
+                if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and get_bool_env_var(
                     "SGL_USE_DEEPGEMM_BMM", "false"
                 ):
                     block_scale = weight_scale
...
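In the DeepSeek-V2 changes above, the DeepGEMM batched-matmul weight path stays opt-in: it requires both the JIT flag and the `SGL_USE_DEEPGEMM_BMM` environment variable (default false). A small check mirroring that condition, assuming an sglang installation:

```python
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.utils import get_bool_env_var

# True only when the JIT path is available and the user explicitly opts in.
use_deepgemm_bmm = deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and get_bool_env_var(
    "SGL_USE_DEEPGEMM_BMM", "false"
)
```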
@@ -11,7 +11,7 @@ from sglang.srt.layers.communicator import (
     ScatterMode,
 )
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
@@ -479,7 +479,9 @@ def _model_forward_tbo(
     )
     del inputs

-    with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms):
+    with deep_gemm_wrapper.configure_deep_gemm_num_sms(
+        operations_strategy.deep_gemm_num_sms
+    ):
         outputs_arr = execute_overlapped_operations(
             inputs_arr=inputs_arr,
             operations_arr=[operations_strategy.operations] * 2,
...