Unverified Commit b4c41f72 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Refactor DeepGEMM integration (#7150)

parent 8b8f2e74
......@@ -4,6 +4,7 @@ from typing import List, Optional
import torch
import triton
from sglang.math_utils import ceil_div
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.utils import dispose_tensor, is_cuda
......@@ -15,11 +16,6 @@ if _is_cuda:
sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
)
try:
from deep_gemm import ceil_div
except ImportError:
logger.error(f"Failed to import ceil_div from deep_gemm.")
import triton.language as tl
......
import logging
from typing import Callable, List, Optional, Tuple
import einops
import torch
from sgl_kernel import silu_and_mul
from torch.nn import Module
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.managers.expert_location import get_global_expert_location_metadata
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.managers.schedule_batch import global_server_args_dict
try:
from deep_gemm import (
get_col_major_tma_aligned_tensor,
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
)
from sgl_kernel import silu_and_mul
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)
use_deep_gemm = True
except ImportError:
use_deep_gemm = False
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
......@@ -45,6 +26,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
......@@ -52,10 +34,20 @@ from sglang.srt.layers.quantization.base_config import (
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8_kernel import (
scaled_fp8_quant,
sglang_per_token_group_quant_fp8,
sglang_per_token_quant_fp8,
)
from sglang.srt.managers.expert_location import get_global_expert_location_metadata
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
from sglang.srt.utils import (
DeepEPMode,
dispose_tensor,
get_bool_env_var,
is_hip,
set_weight_attrs,
)
_is_hip = is_hip()
......@@ -680,7 +672,6 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
......@@ -920,7 +911,9 @@ class DeepEPMoE(EPMoE):
)
self.deepep_mode = deepep_mode
if self.deepep_mode.enable_low_latency():
assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
assert (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
self.w13_weight_fp8 = (
self.w13_weight,
(
......@@ -948,7 +941,7 @@ class DeepEPMoE(EPMoE):
):
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
if resolved_deepep_mode == DeepEPMode.normal:
if _ENABLE_JIT_DEEPGEMM:
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
return self.forward_deepgemm_contiguous(
hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
)
......@@ -1145,7 +1138,7 @@ class DeepEPMoE(EPMoE):
dtype=torch.bfloat16,
)
input_tensor[1] = tma_align_input_scale(input_tensor[1])
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
input_tensor, self.w13_weight_fp8, gateup_output, m_indices
)
del input_tensor
......@@ -1169,7 +1162,7 @@ class DeepEPMoE(EPMoE):
)
del down_input
down_input_scale = tma_align_input_scale(down_input_scale)
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
(down_input_fp8, down_input_scale),
self.w2_weight_fp8,
down_output,
......@@ -1202,8 +1195,13 @@ class DeepEPMoE(EPMoE):
gateup_output = torch.empty(
(num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
hidden_states_fp8,
self.w13_weight_fp8,
gateup_output,
masked_m,
expected_m,
recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_V202506 else None,
)
dispose_tensor(hidden_states_fp8[0])
......@@ -1240,13 +1238,18 @@ class DeepEPMoE(EPMoE):
n = self.w2_weight.size(1)
down_input_fp8 = (
down_input,
get_col_major_tma_aligned_tensor(down_input_scale),
deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
)
down_output = torch.empty(
(num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
down_input_fp8,
self.w2_weight_fp8,
down_output,
masked_m,
expected_m,
recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_V202506 else None,
)
return down_output
......
import logging
from dataclasses import dataclass
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.expert_distribution import (
get_global_expert_distribution_recorder,
)
......@@ -236,14 +236,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
topk_weights: torch.Tensor,
):
topk_idx = topk_idx.to(torch.int64)
if _ENABLE_JIT_DEEPGEMM:
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
# TODO hard code 128 block quant,use fp8 communication
hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
previous_event = Buffer.capture() if self.async_finish else None
return hidden_states, topk_idx, topk_weights, previous_event
def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
if _ENABLE_JIT_DEEPGEMM:
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
(
hidden_states,
topk_idx,
......@@ -345,7 +345,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
previous_event=previous_event,
async_finish=self.async_finish,
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
config=DeepEPConfig.get_instance().normal_dispatch_config,
)
......@@ -409,7 +409,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
):
if _ENABLE_JIT_DEEPGEMM:
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
output = hidden_states
else:
if hidden_states.shape[0] > 0:
......
......@@ -5,33 +5,24 @@ from dataclasses import dataclass
from enum import IntEnum, auto
from typing import Callable, Dict, List, Optional, Tuple
import torch
from tqdm.contrib.concurrent import thread_map
from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
DEEPGEMM_V202506,
ENABLE_JIT_DEEPGEMM,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda
from sglang.srt.utils import get_bool_env_var, get_int_env_var
logger = logging.getLogger(__name__)
_ENABLE_JIT_DEEPGEMM = False
try:
import deep_gemm
from deep_gemm import get_num_sms
from deep_gemm.jit import build
from deep_gemm.jit.compiler import get_nvcc_compiler
from deep_gemm.jit_kernels.gemm import get_best_configs
from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
sm_version = get_device_sm()
if sm_version == 90:
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
_ENABLE_JIT_DEEPGEMM = True
except ImportError:
logger.warning("Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.")
def get_enable_jit_deepgemm():
return _ENABLE_JIT_DEEPGEMM
pass
_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
......@@ -52,8 +43,10 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
# NVRTC may have performance loss with some cases.
# And NVCC JIT speed is also 9x faster in the ref commit
_USE_NVRTC_DEFAULT = "0"
if _ENABLE_JIT_DEEPGEMM:
if ENABLE_JIT_DEEPGEMM:
try:
from deep_gemm.jit.compiler import get_nvcc_compiler
get_nvcc_compiler()
except:
logger.warning(
......@@ -114,6 +107,7 @@ class DeepGemmKernelHelper:
_INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
# TODO improve naming
def _compile_warning_1():
if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
logger.warning(
......@@ -127,6 +121,7 @@ def _compile_warning_1():
)
# TODO improve naming
def _compile_warning_2():
logger.warning(
"Entering DeepGEMM JIT Single Kernel Compile session. "
......@@ -238,6 +233,7 @@ def _compile_gemm_nt_f8f8bf16_one(
_ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
# TODO further refactor warmup-related
_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
......@@ -270,7 +266,6 @@ def _maybe_compile_deep_gemm_one_type_all(
num_groups: int,
m_list: Optional[List[int]] = None,
) -> None:
global _INITIALIZATION_DICT
global _BUILTIN_M_LIST
......@@ -304,56 +299,6 @@ def _maybe_compile_deep_gemm_one_type_all(
thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
def grouped_gemm_nt_f8f8bf16_masked(
lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor,
masked_m: torch.Tensor,
expected_m: int,
):
num_groups, _, k = lhs[0].shape
_, n, _ = rhs[0].shape
kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
_maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
with _log_jit_build(expected_m, n, k, kernel_type):
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
lhs, rhs, out, masked_m, expected_m
)
def grouped_gemm_nt_f8f8bf16_contig(
lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor,
m_indices: torch.Tensor,
):
m, k = lhs[0].shape
num_groups, n, _ = rhs[0].shape
kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
_maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
with _log_jit_build(m, n, k, kernel_type):
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
def gemm_nt_f8f8bf16(
lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor,
):
m, k = lhs[0].shape
n, _ = rhs[0].shape
kernel_type = DeepGemmKernelType.GEMM_NT_F8F8BF16
_maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, 1)
with _log_jit_build(m, n, k, kernel_type):
deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
@contextmanager
def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
if _IN_PRECOMPILE_STAGE:
......@@ -380,13 +325,14 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
@contextmanager
def configure_deep_gemm_num_sms(num_sms):
if num_sms is None:
def deep_gemm_execution_hook(
m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
):
# not supported yet
if DEEPGEMM_V202506:
yield
return
_maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
with _log_jit_build(m, n, k, kernel_type):
yield
else:
original_num_sms = deep_gemm.get_num_sms()
deep_gemm.set_num_sms(num_sms)
try:
yield
finally:
deep_gemm.set_num_sms(original_num_sms)
import logging
from sglang.srt.utils import get_bool_env_var, get_device_sm
logger = logging.getLogger(__name__)
def _compute_enable_deep_gemm():
try:
import deep_gemm
except ImportError:
logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.")
return False
sm_version = get_device_sm()
if sm_version < 90:
return False
return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
DEEPGEMM_V202506 = False
DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_V202506
import logging
from contextlib import contextmanager
from typing import Tuple
import torch
from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils
from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
DEEPGEMM_SCALE_UE8M0,
DEEPGEMM_V202506,
ENABLE_JIT_DEEPGEMM,
)
from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
if ENABLE_JIT_DEEPGEMM:
import deep_gemm
from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
from deep_gemm import get_col_major_tma_aligned_tensor
from deep_gemm import (
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
)
from deep_gemm import (
m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
)
def grouped_gemm_nt_f8f8bf16_masked(
lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor,
masked_m: torch.Tensor,
expected_m: int,
recipe=None,
):
num_groups, _, k = lhs[0].shape
_, n, _ = rhs[0].shape
kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
with compile_utils.deep_gemm_execution_hook(
expected_m, n, k, num_groups, kernel_type
):
_grouped_gemm_nt_f8f8bf16_masked_raw(
lhs, rhs, out, masked_m, expected_m, recipe=recipe
)
def grouped_gemm_nt_f8f8bf16_contig(
lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor,
m_indices: torch.Tensor,
):
m, k = lhs[0].shape
num_groups, n, _ = rhs[0].shape
kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
_grouped_gemm_nt_f8f8bf16_contig_raw(lhs, rhs, out, m_indices)
def gemm_nt_f8f8bf16(
lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor,
):
m, k = lhs[0].shape
n, _ = rhs[0].shape
num_groups = 1
kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
_gemm_nt_f8f8bf16_raw(
lhs,
rhs,
out,
)
def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
compile_utils.update_deep_gemm_config(gpu_id, server_args)
@contextmanager
def configure_deep_gemm_num_sms(num_sms):
if num_sms is None:
yield
else:
original_num_sms = deep_gemm.get_num_sms()
deep_gemm.set_num_sms(num_sms)
try:
yield
finally:
deep_gemm.set_num_sms(original_num_sms)
......@@ -23,7 +23,8 @@ import torch
import triton
import triton.language as tl
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.math_utils import align
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.utils import (
direct_register_custom_op,
get_device_core_count,
......@@ -44,10 +45,6 @@ if _is_cuda:
sgl_per_token_quant_fp8,
)
from sglang.srt.layers.quantization.deep_gemm import (
gemm_nt_f8f8bf16 as deep_gemm_gemm_nt_f8f8bf16,
)
logger = logging.getLogger(__name__)
......@@ -67,7 +64,6 @@ else:
fp8_max = torch.finfo(fp8_dtype).max
fp8_min = -fp8_max
if supports_custom_op():
def deep_gemm_fp8_fp8_bf16_nt(
......@@ -77,7 +73,7 @@ if supports_custom_op():
Bs: torch.Tensor,
C: torch.Tensor,
) -> None:
deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)
def deep_gemm_fp8_fp8_bf16_nt_fake(
A: torch.Tensor,
......@@ -797,12 +793,12 @@ def w8a8_block_fp8_matmul_deepgemm(
M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)
# Deepgemm only supports output tensor type as bfloat16
assert C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM
assert C.dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
if supports_custom_op():
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
else:
deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)
return C
......@@ -896,7 +892,7 @@ def w8a8_block_fp8_matmul(
block_size: List[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
if output_dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
if output_dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
return w8a8_block_fp8_matmul_deepgemm(
A, B, As, Bs, block_size, output_dtype=output_dtype
)
......
import os
from curses import flash
from typing import Callable, List, Optional, Tuple
import einops
import torch
from sglang.math_utils import align
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
from sglang.srt.layers.utils import is_sm100_supported
......@@ -15,7 +15,6 @@ try:
except ImportError:
VLLM_AVAILABLE = False
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.layers.quantization.fp8_kernel import (
fp8_dtype,
fp8_max,
......@@ -138,7 +137,7 @@ def dispatch_w8a8_block_fp8_linear() -> Callable:
return cutlass_w8a8_block_fp8_linear_with_fallback
elif _use_aiter:
return aiter_w8a8_block_fp8_linear
elif _ENABLE_JIT_DEEPGEMM:
elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
return deepgemm_w8a8_block_fp8_linear_with_fallback
else:
return triton_w8a8_block_fp8_linear
......
......@@ -26,6 +26,7 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from sglang.srt import debug_utils
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
......@@ -45,10 +46,9 @@ from sglang.srt.layers.dp_attention import (
initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
from sglang.srt.layers.quantization.deep_gemm import (
_ENABLE_JIT_DEEPGEMM,
update_deep_gemm_config,
from sglang.srt.layers.quantization import (
deep_gemm_wrapper,
monkey_patch_isinstance_for_vllm_base_layer,
)
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
......@@ -205,8 +205,8 @@ class ModelRunner:
min_per_gpu_memory = self.init_torch_distributed()
# Update deep gemm configure
if _ENABLE_JIT_DEEPGEMM:
update_deep_gemm_config(gpu_id, server_args)
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
# If it is a draft model, tp_group can be different
self.initialize(min_per_gpu_memory)
......
......@@ -54,8 +54,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
per_tensor_quant_mla_fp8,
......@@ -110,10 +110,6 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if _is_cuda:
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
from sglang.srt.layers.quantization.deep_gemm import (
grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
)
else:
from vllm._custom_ops import awq_dequantize
......@@ -981,7 +977,7 @@ class DeepseekV2AttentionMLA(nn.Module):
q_nope_out = q_nope.new_empty(
(self.num_local_heads, aligned_m, self.kv_lora_rank)
)
deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
(q_nope_val, q_nope_scale),
(self.w_kc, self.w_scale_k),
q_nope_out,
......@@ -1851,7 +1847,7 @@ class DeepseekV2ForCausalLM(nn.Module):
and weight_block_size[1] == 128
and model_dtype == torch.bfloat16
):
if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and get_bool_env_var(
"SGL_USE_DEEPGEMM_BMM", "false"
):
block_scale = weight_scale
......
......@@ -11,7 +11,7 @@ from sglang.srt.layers.communicator import (
ScatterMode,
)
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.operations import execute_operations, execute_overlapped_operations
......@@ -479,7 +479,9 @@ def _model_forward_tbo(
)
del inputs
with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms):
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
operations_strategy.deep_gemm_num_sms
):
outputs_arr = execute_overlapped_operations(
inputs_arr=inputs_arr,
operations_arr=[operations_strategy.operations] * 2,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment