Unverified Commit acc816d8 authored by lukec, committed by GitHub

DeepEP normal support deepgemm-contiguous (#5626)


Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com>
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Co-authored-by: Xuting Zhou <xutingz@nvidia.com>
Co-authored-by: ZhengHSI <zhenghsi@qq.com>
parent a05bd83a
@@ -5,16 +5,23 @@ import torch
import triton
import triton.language as tl
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.utils import is_cuda

logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
if _is_cuda:
    from sglang.srt.layers.quantization.fp8_kernel import (
        sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
    )

try:
    from deep_gemm import ceil_div
except ImportError:
    logger.error("Failed to import ceil_div from deep_gemm.")
@triton.jit
@@ -704,3 +711,334 @@ def grouped_gemm_triton(
**config,
)
return c
@triton.jit
def _fwd_kernel_ep_scatter_1(
num_recv_tokens_per_expert,
expert_start_loc,
m_indices,
num_experts: tl.constexpr,
BLOCK_E: tl.constexpr,
BLOCK_EXPERT_NUM: tl.constexpr,
):
cur_expert = tl.program_id(0)
offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
tokens_per_expert = tl.load(
num_recv_tokens_per_expert + offset_cumsum,
mask=offset_cumsum < num_experts,
other=0,
)
cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
cur_expert_start = tl.load(expert_start_loc + cur_expert)
cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)
m_indices_start_ptr = m_indices + cur_expert_start
off_expert = tl.arange(0, BLOCK_E)
for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
tl.store(
m_indices_start_ptr + start_m + off_expert,
cur_expert,
)
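# NOTE: a minimal PyTorch sketch (added commentary, not part of the kernel) of
# what _fwd_kernel_ep_scatter_1 computes. `tl.cumsum(x) - x` is an exclusive
# prefix sum, so each expert's start offset is the total token count of all
# preceding experts, and m_indices tags every row of the scattered buffer with
# its expert id. The unmasked BLOCK_E store is safe because dispatch aligns
# each per-expert count to 128:
#
#     counts = num_recv_tokens_per_expert              # int32, [num_experts]
#     expert_start_loc = torch.cumsum(counts, 0) - counts
#     for e, (start, n) in enumerate(zip(expert_start_loc, counts)):
#         m_indices[start : start + n] = e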
@triton.jit
def _fwd_kernel_ep_scatter_2(
total_token_num,
expert_start_loc,
recv_x,
recv_x_stride0,
recv_x_stride1,
recv_x_scale,
recv_x_scale_stride0,
recv_x_scale_stride1,
recv_topk,
recv_topk_stride0,
recv_topk_stride1,
output_tensor,
output_tensor_stride0,
output_tensor_stride1,
output_tensor_scale,
output_tensor_scale_stride0,
output_tensor_scale_stride1,
output_index,
output_index_stride0,
output_index_stride1,
topk_num: tl.constexpr,
HIDDEN_SIZE: tl.constexpr,
HIDDEN_SIZE_PAD: tl.constexpr,
SCALE_HIDDEN_SIZE: tl.constexpr,
SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
):
start_token_id = tl.program_id(0)
grid_num = tl.num_programs(0)
offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
mask = offset_in < HIDDEN_SIZE
offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
mask_s = offset_in_s < SCALE_HIDDEN_SIZE
for token_id in range(start_token_id, total_token_num, grid_num):
to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
to_copy_s = tl.load(
recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
)
for topk_index in tl.range(0, topk_num, 1, num_stages=4):
expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
if expert_id >= 0:
dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
tl.store(
output_index + token_id * output_index_stride0 + topk_index,
dest_token_index,
)
output_tensor_ptr = (
output_tensor + dest_token_index * output_tensor_stride0
)
output_tensor_scale_ptr = (
output_tensor_scale + dest_token_index * output_tensor_scale_stride0
)
tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
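# NOTE (added commentary): tl.atomic_add on expert_start_loc hands every
# (token, expert) pair a unique destination row inside that expert's
# contiguous segment, so concurrent programs never collide on a slot.
# output_index records the token -> destination-row mapping, which ep_gather
# later uses to invert the scatter. This also means kernel 1 must have
# finished writing the exclusive cumsum into expert_start_loc before this
# kernel runs (guaranteed by launching both on the same stream).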
# Copied from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
@torch.no_grad()
def ep_scatter(
recv_x: torch.Tensor,
recv_x_scale: torch.Tensor,
recv_topk: torch.Tensor,
num_recv_tokens_per_expert: torch.Tensor,
expert_start_loc: torch.Tensor,
output_tensor: torch.Tensor,
output_tensor_scale: torch.Tensor,
m_indices: torch.Tensor,
output_index: torch.Tensor,
):
    BLOCK_E = 128  # the token count of each expert is aligned to 128
    BLOCK_D = 128  # block size of the FP8 quantization groups
num_warps = 8
num_experts = num_recv_tokens_per_expert.shape[0]
hidden_size = recv_x.shape[1]
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid = num_experts
assert m_indices.shape[0] % BLOCK_E == 0
_fwd_kernel_ep_scatter_1[(grid,)](
num_recv_tokens_per_expert,
expert_start_loc,
m_indices,
num_experts=num_experts,
num_warps=num_warps,
BLOCK_E=BLOCK_E,
BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
)
grid = min(recv_topk.shape[0], 1024 * 8)
_fwd_kernel_ep_scatter_2[(grid,)](
recv_topk.shape[0],
expert_start_loc,
recv_x,
recv_x.stride(0),
recv_x.stride(1),
recv_x_scale,
recv_x_scale.stride(0),
recv_x_scale.stride(1),
recv_topk,
recv_topk.stride(0),
recv_topk.stride(1),
output_tensor,
output_tensor.stride(0),
output_tensor.stride(1),
output_tensor_scale,
output_tensor_scale.stride(0),
output_tensor_scale.stride(1),
output_index,
output_index.stride(0),
output_index.stride(1),
topk_num=recv_topk.shape[1],
num_warps=num_warps,
HIDDEN_SIZE=hidden_size,
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
)
return
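# Hedged usage sketch for ep_scatter: shapes below are inferred from the call
# site in forward_deepgemm_contiguous (hidden size H, quantization block 128,
# all_tokens padded to a multiple of 128 by expert_alignment at dispatch):
#
#     recv_x:                     [num_recv_tokens, H]        fp8
#     recv_x_scale:               [num_recv_tokens, H // 128] float32
#     recv_topk:                  [num_recv_tokens, topk]     int64
#     num_recv_tokens_per_expert: [num_experts]               int32
#     output_tensor / _scale:     [all_tokens, H] / [all_tokens, H // 128]
#     m_indices:                  [all_tokens]                int32
#     output_index:               same shape as recv_topk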
@triton.jit
def _fwd_kernel_ep_gather(
total_token_num,
input_tensor,
input_tensor_stride0,
input_tensor_stride1,
recv_topk_ids,
recv_topk_ids_stride0,
recv_topk_ids_stride1,
recv_topk_weight,
recv_topk_weight_stride0,
recv_topk_weight_stride1,
input_index,
input_index_stride0,
input_index_stride1,
output_tensor,
output_tensor_stride0,
output_tensor_stride1,
topk_num: tl.constexpr,
BLOCK_D: tl.constexpr,
):
cur_block = tl.program_id(0)
start_cur_token = tl.program_id(1)
grid_num = tl.num_programs(1)
for cur_token in range(start_cur_token, total_token_num, grid_num):
off_d = tl.arange(0, BLOCK_D)
accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
for topk_index in range(0, topk_num):
expert_id = tl.load(
recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
)
if expert_id >= 0:
source_token_index = tl.load(
input_index + cur_token * input_index_stride0 + topk_index
)
acc_weight = tl.load(
recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
)
tmp = tl.load(
input_tensor
+ source_token_index * input_tensor_stride0
+ cur_block * BLOCK_D
+ off_d
)
accumulator += tmp.to(tl.float32) * acc_weight
tl.store(
output_tensor
+ cur_token * output_tensor_stride0
+ cur_block * BLOCK_D
+ off_d,
accumulator.to(output_tensor.dtype.element_ty),
)
@torch.no_grad()
def ep_gather(
input_tensor: torch.Tensor,
recv_topk_ids: torch.Tensor,
recv_topk_weight: torch.Tensor,
input_index: torch.Tensor,
output_tensor: torch.Tensor,
):
    BLOCK_D = 1024  # tile size along the hidden dimension (not a quantization block)
num_warps = 2
num_tokens = output_tensor.shape[0]
hidden_size = input_tensor.shape[1]
assert hidden_size % BLOCK_D == 0
grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
_fwd_kernel_ep_gather[grid](
num_tokens,
input_tensor,
input_tensor.stride(0),
input_tensor.stride(1),
recv_topk_ids,
recv_topk_ids.stride(0),
recv_topk_ids.stride(1),
recv_topk_weight,
recv_topk_weight.stride(0),
recv_topk_weight.stride(1),
input_index,
input_index.stride(0),
input_index.stride(1),
output_tensor,
output_tensor.stride(0),
output_tensor.stride(1),
topk_num=recv_topk_ids.shape[1],
num_warps=num_warps,
BLOCK_D=BLOCK_D,
)
return
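# NOTE: a rough, unoptimized PyTorch reference for ep_gather (added for
# clarity; a sketch, not the shipped implementation). It reads each token's
# top-k expert outputs back through input_index and accumulates them with the
# routing weights in float32, matching the kernel's accumulator:
#
#     def ep_gather_ref(input_tensor, recv_topk_ids, recv_topk_weight,
#                       input_index, output_tensor):
#         acc = torch.zeros(output_tensor.shape, dtype=torch.float32,
#                           device=output_tensor.device)
#         for t in range(output_tensor.shape[0]):
#             for j in range(recv_topk_ids.shape[1]):
#                 if recv_topk_ids[t, j] >= 0:
#                     src = input_index[t, j]
#                     acc[t] += input_tensor[src].float() * recv_topk_weight[t, j]
#         output_tensor.copy_(acc.to(output_tensor.dtype))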
# Copied from
# https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58
def get_tma_aligned_size(x: int, element_size: int) -> int:
"""
Global memory address of TMA must be 16-byte aligned.
Since we use column-major layout for the LHS scaling tensor,
the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.
Arguments:
x: original M-axis shape of the LHS scaling tensor.
element_size: element size of the LHS scaling tensor.
Returns:
M-axis shape of the LHS scaling tensor after padding.
"""
tma_alignment_bytes = 16
assert tma_alignment_bytes % element_size == 0
alignment = tma_alignment_bytes // element_size
return ceil_div(x, alignment) * alignment
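# Worked example: for float32 scales (element_size = 4), the alignment is
# 16 // 4 = 4 elements, so get_tma_aligned_size(10, 4) == 12; with an
# element_size of 1, the M axis would instead be padded up to a multiple of 16.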
@triton.jit
def _tma_align_input_scale_kernel(
input_scale_ptr,
output_ptr,
m,
k_div_block_size,
input_scale_stride_m,
input_scale_stride_k,
output_stride_m,
output_stride_k,
BLOCK_SIZE_K: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
grid_m = tl.num_programs(0)
k_offsets = tl.arange(0, BLOCK_SIZE_K)
for m_base in range(pid_m, m, grid_m):
input_offset = (
input_scale_ptr
+ m_base * input_scale_stride_m
+ k_offsets * input_scale_stride_k
)
input_data = tl.load(input_offset, mask=k_offsets < k_div_block_size)
output_offset = (
output_ptr + k_offsets * output_stride_k + m_base * output_stride_m
)
tl.store(output_offset, input_data, mask=k_offsets < k_div_block_size)
# Copied from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py
def tma_align_input_scale(input_scale: torch.Tensor):
assert input_scale.dim() == 2
m, k_div_block_size = input_scale.shape
padd_m = get_tma_aligned_size(m, input_scale.element_size())
output = torch.empty(
(k_div_block_size, padd_m), dtype=input_scale.dtype, device=input_scale.device
)
grid_m = min(m, 8192)
BLOCK_SIZE_K = triton.next_power_of_2(k_div_block_size)
_tma_align_input_scale_kernel[(grid_m,)](
input_scale_ptr=input_scale,
output_ptr=output,
m=m,
k_div_block_size=k_div_block_size,
input_scale_stride_m=input_scale.stride(0),
input_scale_stride_k=input_scale.stride(1),
output_stride_m=output.stride(1), # Note: these are swapped
output_stride_k=output.stride(0), # for column-major
BLOCK_SIZE_K=BLOCK_SIZE_K,
)
return output.t()[:m]
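# NOTE (added commentary): the buffer is allocated as (k_div_block_size,
# padd_m) and transposed on return, so the [m, k] result is column-major with
# padd_m rows physically allocated -- the layout get_tma_aligned_size
# describes for the LHS scale tensor. A rough shape check under assumed
# inputs:
#
#     s = torch.rand(10, 56, device="cuda", dtype=torch.float32)
#     t = tma_align_input_scale(s)   # t.shape == (10, 56)
#     assert t.stride(1) == get_tma_aligned_size(10, s.element_size())  # 12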
@@ -4,11 +4,19 @@ from typing import Callable, List, Optional, Tuple
import torch
from torch.nn import Module
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
try:
from deep_gemm import (
get_col_major_tma_aligned_tensor,
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
)
from sgl_kernel import silu_and_mul
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)
use_deep_gemm = True
except ImportError:
@@ -20,6 +28,8 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather,
ep_scatter,
gelu_and_mul_triton_kernel,
grouped_gemm_triton,
post_reorder_triton_kernel,
@@ -27,6 +37,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
run_moe_ep_preproess,
silu_and_mul_masked_post_quant_fwd,
silu_and_mul_triton_kernel,
tma_align_input_scale,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
@@ -842,15 +853,23 @@ class DeepEPMoE(EPMoE):
def forward(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
reorder_topk_ids: torch.Tensor,
seg_indptr: torch.Tensor,
masked_m: torch.Tensor,
expected_m: int,
num_recv_tokens_per_expert: List[int],
forward_mode: ForwardMode,
):
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
if resolved_deepep_mode == DeepEPMode.normal:
            if _ENABLE_JIT_DEEPGEMM:
                return self.forward_deepgemm_contiguous(
                    hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
                )
            else:
                return self.forward_normal(
                    hidden_states, reorder_topk_ids, seg_indptr
                )
elif resolved_deepep_mode == DeepEPMode.low_latency:
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
else:
@@ -969,6 +988,106 @@ class DeepEPMoE(EPMoE):
)
return down_output
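    # Added commentary: the contiguous DeepGEMM path below runs in five steps:
    #   1) ep_scatter groups the FP8 tokens (and their 128-block scales) by
    #      expert, recording m_indices plus the inverse permutation
    #      (output_index);
    #   2) m_grouped_gemm_fp8_fp8_bf16_nt_contiguous applies w13 per expert;
    #   3) silu_and_mul fuses the gate/up halves of the activation;
    #   4) the result is re-quantized to FP8 with
    #      sglang_per_token_group_quant_fp8 and multiplied by w2 in a second
    #      grouped GEMM;
    #   5) ep_gather applies the top-k routing weights and sums per token.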
def forward_deepgemm_contiguous(
self,
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
topk_idx,
topk_weights,
num_recv_tokens_per_expert: List[int],
):
hidden_states_fp8, hidden_states_scale = hidden_states_fp8
assert self.quant_method is not None
assert self.activation == "silu"
if num_recv_tokens_per_expert is None:
return hidden_states_fp8.bfloat16()
all_tokens = sum(num_recv_tokens_per_expert)
if all_tokens <= 0:
return hidden_states_fp8.bfloat16()
M, K = hidden_states_fp8.size()
N = self.w13_weight.size(1)
scale_block_size = 128
gather_out = torch.empty_like(
hidden_states_fp8,
device=hidden_states_fp8.device,
dtype=torch.bfloat16,
)
input_tensor = [
torch.empty(
(all_tokens, K),
device=hidden_states_fp8.device,
dtype=hidden_states_fp8.dtype,
),
torch.empty(
(all_tokens, K // 128),
device=hidden_states_fp8.device,
dtype=torch.float32,
),
]
m_indices = torch.empty(
all_tokens, device=hidden_states_fp8.device, dtype=torch.int32
)
output_index = torch.empty_like(topk_idx)
num_recv_tokens_per_expert_gpu = torch.tensor(
num_recv_tokens_per_expert,
dtype=torch.int32,
pin_memory=True,
device="cpu",
).cuda(non_blocking=True)
expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
ep_scatter(
hidden_states_fp8,
hidden_states_scale,
topk_idx,
num_recv_tokens_per_expert_gpu,
expert_start_loc,
input_tensor[0],
input_tensor[1],
m_indices,
output_index,
)
gateup_output = torch.empty(
(all_tokens, N),
device=hidden_states_fp8.device,
dtype=torch.bfloat16,
)
input_tensor[1] = tma_align_input_scale(input_tensor[1])
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
input_tensor, self.w13_weight_fp8, gateup_output, m_indices
)
down_input = torch.empty(
(
all_tokens,
N // 2,
),
device=gateup_output.device,
dtype=torch.bfloat16,
)
silu_and_mul(gateup_output.view(-1, N), down_input)
down_output = torch.empty(
(all_tokens, K),
device=hidden_states_fp8.device,
dtype=torch.bfloat16,
)
down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
down_input, scale_block_size
)
down_input_scale = tma_align_input_scale(down_input_scale)
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(down_input_fp8, down_input_scale),
self.w2_weight_fp8,
down_output,
m_indices,
)
ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
return gather_out
def forward_deepgemm_masked(
self,
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
......
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.utils import DeepEPMode
try:
from deep_ep import Buffer
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)
use_deepep = True
except ImportError:
use_deepep = False
from enum import IntEnum, auto
from typing import Optional, Tuple, Union
import torch
import torch.distributed as dist
@@ -78,7 +83,6 @@ class DeepEPBuffer:
),
num_rdma_bytes,
)
cls._buffer = Buffer(
group,
num_nvl_bytes,
@@ -181,44 +185,74 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
topk_weights: torch.Tensor,
):
topk_idx = topk_idx.to(torch.int64)
if _ENABLE_JIT_DEEPGEMM:
            # TODO: the 128 block size for quantization is hard-coded here; uses FP8 communication
hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
previous_event = Buffer.capture() if self.async_finish else None
return hidden_states, topk_idx, topk_weights, previous_event
def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
        if _ENABLE_JIT_DEEPGEMM:
            (
                hidden_states,
                topk_idx,
                topk_weights,
                num_recv_tokens_per_expert_list,
                event,
            ) = self._dispatch_core(
                hidden_states, topk_idx, topk_weights, previous_event
            )
            event.current_stream_wait() if self.async_finish else ()
            return (
                hidden_states,
                topk_idx,
                topk_weights,
                None,
                num_recv_tokens_per_expert_list,
                None,
                None,
                None,
            )
        else:
            (
                hidden_states,
                topk_idx,
                topk_weights,
                num_recv_tokens_per_expert_list,
                event,
            ) = self._dispatch_core(
                hidden_states, topk_idx, topk_weights, previous_event
            )
            event.current_stream_wait() if self.async_finish else ()
            if hidden_states.shape[0] > 0:
                reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
                    hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
                )
            else:
                reorder_topk_ids = torch.empty(
                    (0,), device=hidden_states.device, dtype=torch.int64
                )
                seg_indptr = torch.zeros(
                    (self.num_experts + 1,),
                    device=hidden_states.device,
                    dtype=torch.int64,
                )
            masked_m = expected_m = None
            return (
                hidden_states,
                topk_idx,
                topk_weights,
                reorder_topk_ids,
                None,
                seg_indptr,
                masked_m,
                expected_m,
            )
def _dispatch_core(
self,
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
previous_event,
@@ -246,7 +280,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
recv_x,
recv_topk_idx,
recv_topk_weights,
num_recv_tokens_per_expert_list,
self.handle,
event,
) = buffer.dispatch(
@@ -260,12 +294,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
previous_event=previous_event,
async_finish=self.async_finish,
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
)
return (
recv_x,
recv_topk_idx,
recv_topk_weights,
num_recv_tokens_per_expert_list,
event,
)
@@ -314,29 +350,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
):
        if _ENABLE_JIT_DEEPGEMM:
            output = hidden_states
        else:
            if hidden_states.shape[0] > 0:
                num_tokens = self.src2dst.shape[0] // self.router_topk
                output = torch.empty(
                    (num_tokens, hidden_states.shape[1]),
                    device=hidden_states.device,
                    dtype=hidden_states.dtype,
                )
                deepep_post_reorder_triton_kernel[(num_tokens,)](
                    hidden_states,
                    output,
                    self.src2dst,
                    topk_idx,
                    topk_weights,
                    self.router_topk,
                    hidden_states.shape[1],
                    BLOCK_SIZE=512,
                )
            else:
                output = torch.zeros(
                    (0, hidden_states.shape[1]),
                    device=hidden_states.device,
                    dtype=hidden_states.dtype,
                )
previous_event = Buffer.capture() if self.async_finish else None
return output, previous_event
@@ -360,6 +399,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
def _get_buffer(self):
DeepEPBuffer.set_dispatch_mode_as_normal()
return DeepEPBuffer.get_deepep_buffer(
self.group,
self.hidden_size,
@@ -426,6 +466,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
topk_idx,
topk_weights,
reorder_topk_ids,
None,
seg_indptr,
masked_m,
expected_m,
@@ -570,7 +611,8 @@ class DeepEPDispatcher:
def dispatch(self, *args, **kwargs) -> Tuple:
self.dispatch_a(*args, **kwargs)
ret = self.dispatch_b()
return ret
def dispatch_a(
self,
@@ -593,7 +635,8 @@ class DeepEPDispatcher:
def combine(self, *args, **kwargs) -> Tuple:
self.combine_a(*args, **kwargs)
ret = self.combine_b()
return ret
def combine_a(
self,
......
@@ -28,6 +28,11 @@ if is_cuda():
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
_ENABLE_JIT_DEEPGEMM = True
def get_enable_jit_deepgemm():
return _ENABLE_JIT_DEEPGEMM
logger = logging.getLogger(__name__)
_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
......
@@ -308,8 +308,8 @@ def sglang_per_token_group_quant_fp8(
device=x.device,
dtype=torch.float32,
)
if x.shape[0] > 0:
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
return x_q, x_s
......
@@ -357,6 +357,7 @@ class DeepseekV2MoE(nn.Module):
topk_idx,
topk_weights,
reorder_topk_ids,
num_recv_tokens_per_expert,
seg_indptr,
masked_m,
expected_m,
@@ -368,10 +369,13 @@ class DeepseekV2MoE(nn.Module):
)
final_hidden_states = self.experts(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
reorder_topk_ids=reorder_topk_ids,
seg_indptr=seg_indptr,
masked_m=masked_m,
expected_m=expected_m,
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
forward_mode=forward_mode,
)
if self.ep_size > 1:
......