Unverified commit acc816d8 authored by lukec, committed by GitHub

DeepEP normal support deepgemm-contiguous (#5626)


Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com>
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Co-authored-by: Xuting Zhou <xutingz@nvidia.com>
Co-authored-by: ZhengHSI <zhenghsi@qq.com>
parent a05bd83a
@@ -5,16 +5,23 @@ import torch
import triton
import triton.language as tl

from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.utils import is_cuda

logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
if _is_cuda:
    from sglang.srt.layers.quantization.fp8_kernel import (
        sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
    )

try:
    from deep_gemm import ceil_div
except ImportError:
    logger.error("Failed to import ceil_div from deep_gemm.")


@triton.jit
@@ -704,3 +711,334 @@ def grouped_gemm_triton(
        **config,
    )
    return c
@triton.jit
def _fwd_kernel_ep_scatter_1(
    num_recv_tokens_per_expert,
    expert_start_loc,
    m_indices,
    num_experts: tl.constexpr,
    BLOCK_E: tl.constexpr,
    BLOCK_EXPERT_NUM: tl.constexpr,
):
    cur_expert = tl.program_id(0)

    # Exclusive prefix sum over per-expert token counts gives each expert's
    # start offset in the packed layout.
    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
    tokens_per_expert = tl.load(
        num_recv_tokens_per_expert + offset_cumsum,
        mask=offset_cumsum < num_experts,
        other=0,
    )
    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)

    cur_expert_start = tl.load(expert_start_loc + cur_expert)
    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)

    # Fill this expert's slice of m_indices with its expert id.
    m_indices_start_ptr = m_indices + cur_expert_start
    off_expert = tl.arange(0, BLOCK_E)

    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
        tl.store(
            m_indices_start_ptr + start_m + off_expert,
            cur_expert,
        )
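To make the scatter bookkeeping concrete, here is a small pure-Python reference of what `_fwd_kernel_ep_scatter_1` computes. This is a sketch for illustration only; `ep_scatter_1_reference` is a hypothetical name, and in the real kernel the per-expert counts have already been padded to BLOCK_E = 128 by dispatch.

# Hypothetical pure-Python reference for _fwd_kernel_ep_scatter_1 (illustration only).
import itertools

def ep_scatter_1_reference(num_recv_tokens_per_expert):
    # Exclusive prefix sum: where each expert's segment starts in the packed layout.
    expert_start_loc = list(
        itertools.accumulate([0] + num_recv_tokens_per_expert[:-1])
    )
    # m_indices[i] = id of the expert that owns packed row i.
    m_indices = []
    for expert_id, count in enumerate(num_recv_tokens_per_expert):
        m_indices.extend([expert_id] * count)
    return expert_start_loc, m_indices

# [256, 128, 384] tokens per expert -> starts [0, 256, 384]
print(ep_scatter_1_reference([256, 128, 384])[0])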
@triton.jit
def _fwd_kernel_ep_scatter_2(
    total_token_num,
    expert_start_loc,
    recv_x,
    recv_x_stride0,
    recv_x_stride1,
    recv_x_scale,
    recv_x_scale_stride0,
    recv_x_scale_stride1,
    recv_topk,
    recv_topk_stride0,
    recv_topk_stride1,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    output_tensor_scale,
    output_tensor_scale_stride0,
    output_tensor_scale_stride1,
    output_index,
    output_index_stride0,
    output_index_stride1,
    topk_num: tl.constexpr,
    HIDDEN_SIZE: tl.constexpr,
    HIDDEN_SIZE_PAD: tl.constexpr,
    SCALE_HIDDEN_SIZE: tl.constexpr,
    SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
):
    start_token_id = tl.program_id(0)
    grid_num = tl.num_programs(0)

    offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
    mask = offset_in < HIDDEN_SIZE

    offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
    mask_s = offset_in_s < SCALE_HIDDEN_SIZE

    for token_id in range(start_token_id, total_token_num, grid_num):
        to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
        to_copy_s = tl.load(
            recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
        )
        for topk_index in tl.range(0, topk_num, 1, num_stages=4):
            expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
            if expert_id >= 0:
                # Claim the next free row in this expert's segment and remember
                # where this (token, expert) pair was written for the gather step.
                dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
                tl.store(
                    output_index + token_id * output_index_stride0 + topk_index,
                    dest_token_index,
                )
                output_tensor_ptr = (
                    output_tensor + dest_token_index * output_tensor_stride0
                )
                output_tensor_scale_ptr = (
                    output_tensor_scale + dest_token_index * output_tensor_scale_stride0
                )
                tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
                tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
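The kernel above can be read as the following CPU sketch. The helper name is hypothetical, the scale copy is omitted for brevity, and the `tl.atomic_add` becomes an ordinary per-expert write cursor since the loop is sequential.

# Hypothetical CPU sketch of _fwd_kernel_ep_scatter_2 (scales omitted for brevity).
import torch

def ep_scatter_2_reference(recv_x, recv_topk, expert_start_loc, all_tokens):
    num_tokens, topk = recv_topk.shape
    output_tensor = torch.zeros(all_tokens, recv_x.shape[1], dtype=recv_x.dtype)
    output_index = torch.empty_like(recv_topk)
    next_free = expert_start_loc.clone()  # sequential stand-in for tl.atomic_add
    for t in range(num_tokens):
        for j in range(topk):
            expert = recv_topk[t, j].item()
            if expert >= 0:  # negative ids mean "not routed to this rank"
                dest = next_free[expert].item()
                next_free[expert] += 1
                output_index[t, j] = dest  # remembered for ep_gather later
                output_tensor[dest] = recv_x[t]  # token duplicated once per expert
    return output_tensor, output_index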
# Copied from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
@torch.no_grad()
def ep_scatter(
    recv_x: torch.Tensor,
    recv_x_scale: torch.Tensor,
    recv_topk: torch.Tensor,
    num_recv_tokens_per_expert: torch.Tensor,
    expert_start_loc: torch.Tensor,
    output_tensor: torch.Tensor,
    output_tensor_scale: torch.Tensor,
    m_indices: torch.Tensor,
    output_index: torch.Tensor,
):
    BLOCK_E = 128  # token count per expert is aligned to 128
    BLOCK_D = 128  # block size of the quantization groups
    num_warps = 8
    num_experts = num_recv_tokens_per_expert.shape[0]
    hidden_size = recv_x.shape[1]
    # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
    grid = num_experts

    assert m_indices.shape[0] % BLOCK_E == 0

    _fwd_kernel_ep_scatter_1[(grid,)](
        num_recv_tokens_per_expert,
        expert_start_loc,
        m_indices,
        num_experts=num_experts,
        num_warps=num_warps,
        BLOCK_E=BLOCK_E,
        BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
    )

    grid = min(recv_topk.shape[0], 1024 * 8)

    _fwd_kernel_ep_scatter_2[(grid,)](
        recv_topk.shape[0],
        expert_start_loc,
        recv_x,
        recv_x.stride(0),
        recv_x.stride(1),
        recv_x_scale,
        recv_x_scale.stride(0),
        recv_x_scale.stride(1),
        recv_topk,
        recv_topk.stride(0),
        recv_topk.stride(1),
        output_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        output_tensor_scale,
        output_tensor_scale.stride(0),
        output_tensor_scale.stride(1),
        output_index,
        output_index.stride(0),
        output_index.stride(1),
        topk_num=recv_topk.shape[1],
        num_warps=num_warps,
        HIDDEN_SIZE=hidden_size,
        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
        SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
    )
    return
@triton.jit
def _fwd_kernel_ep_gather(
    total_token_num,
    input_tensor,
    input_tensor_stride0,
    input_tensor_stride1,
    recv_topk_ids,
    recv_topk_ids_stride0,
    recv_topk_ids_stride1,
    recv_topk_weight,
    recv_topk_weight_stride0,
    recv_topk_weight_stride1,
    input_index,
    input_index_stride0,
    input_index_stride1,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    topk_num: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    cur_block = tl.program_id(0)
    start_cur_token = tl.program_id(1)
    grid_num = tl.num_programs(1)

    for cur_token in range(start_cur_token, total_token_num, grid_num):
        off_d = tl.arange(0, BLOCK_D)
        accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
        for topk_index in range(0, topk_num):
            expert_id = tl.load(
                recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
            )
            if expert_id >= 0:
                source_token_index = tl.load(
                    input_index + cur_token * input_index_stride0 + topk_index
                )
                acc_weight = tl.load(
                    recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
                )
                tmp = tl.load(
                    input_tensor
                    + source_token_index * input_tensor_stride0
                    + cur_block * BLOCK_D
                    + off_d
                )
                # Accumulate the expert outputs in FP32, weighted by the top-k weight.
                accumulator += tmp.to(tl.float32) * acc_weight
        tl.store(
            output_tensor
            + cur_token * output_tensor_stride0
            + cur_block * BLOCK_D
            + off_d,
            accumulator.to(output_tensor.dtype.element_ty),
        )
@torch.no_grad()
def ep_gather(
    input_tensor: torch.Tensor,
    recv_topk_ids: torch.Tensor,
    recv_topk_weight: torch.Tensor,
    input_index: torch.Tensor,
    output_tensor: torch.Tensor,
):
    BLOCK_D = 1024  # block size along the hidden dimension
    num_warps = 2
    num_tokens = output_tensor.shape[0]
    hidden_size = input_tensor.shape[1]
    assert hidden_size % BLOCK_D == 0
    grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
    _fwd_kernel_ep_gather[grid](
        num_tokens,
        input_tensor,
        input_tensor.stride(0),
        input_tensor.stride(1),
        recv_topk_ids,
        recv_topk_ids.stride(0),
        recv_topk_ids.stride(1),
        recv_topk_weight,
        recv_topk_weight.stride(0),
        recv_topk_weight.stride(1),
        input_index,
        input_index.stride(0),
        input_index.stride(1),
        output_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        topk_num=recv_topk_ids.shape[1],
        num_warps=num_warps,
        BLOCK_D=BLOCK_D,
    )
    return
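ep_gather reverses the scatter: each original token's output is the weighted sum of the rows its experts produced, located through the `output_index` recorded during scatter. A CPU reference sketch follows; the function name is hypothetical, and FP32 accumulation mirrors the kernel.

# Hypothetical CPU reference for ep_gather (illustration only).
import torch

def ep_gather_reference(input_tensor, recv_topk_ids, recv_topk_weight, input_index):
    num_tokens, topk = recv_topk_ids.shape
    out = torch.zeros(num_tokens, input_tensor.shape[1], dtype=torch.float32)
    for t in range(num_tokens):
        for j in range(topk):
            if recv_topk_ids[t, j] >= 0:  # skip slots not routed to this rank
                src = input_index[t, j]  # row written by ep_scatter for (t, j)
                out[t] += input_tensor[src].float() * recv_topk_weight[t, j]
    return out.to(input_tensor.dtype)  # the kernel casts back to the output dtype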
# Copied from
# https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58
def get_tma_aligned_size(x: int, element_size: int) -> int:
    """
    Global memory address of TMA must be 16-byte aligned.
    Since we use column-major layout for the LHS scaling tensor,
    the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.

    Arguments:
        x: original M-axis shape of the LHS scaling tensor.
        element_size: element size of the LHS scaling tensor.

    Returns:
        M-axis shape of the LHS scaling tensor after padding.
    """
    tma_alignment_bytes = 16
    assert tma_alignment_bytes % element_size == 0
    alignment = tma_alignment_bytes // element_size
    return ceil_div(x, alignment) * alignment
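As a worked example of the padding rule (restated locally so it runs standalone): with FP32 scales, element_size is 4 bytes, so the alignment is 16 / 4 = 4 elements, and an M of 130 rows pads up to 132.

# Worked example of get_tma_aligned_size (self-contained restatement).
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

tma_alignment_bytes = 16
element_size = 4  # FP32 scales
alignment = tma_alignment_bytes // element_size  # 4 elements
print(ceil_div(130, alignment) * alignment)  # -> 132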
@triton.jit
def _tma_align_input_scale_kernel(
    input_scale_ptr,
    output_ptr,
    m,
    k_div_block_size,
    input_scale_stride_m,
    input_scale_stride_k,
    output_stride_m,
    output_stride_k,
    BLOCK_SIZE_K: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    grid_m = tl.num_programs(0)
    k_offsets = tl.arange(0, BLOCK_SIZE_K)

    for m_base in range(pid_m, m, grid_m):
        input_offset = (
            input_scale_ptr
            + m_base * input_scale_stride_m
            + k_offsets * input_scale_stride_k
        )
        input_data = tl.load(input_offset, mask=k_offsets < k_div_block_size)

        output_offset = (
            output_ptr + k_offsets * output_stride_k + m_base * output_stride_m
        )
        tl.store(output_offset, input_data, mask=k_offsets < k_div_block_size)
# Copied from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py
def tma_align_input_scale(input_scale: torch.Tensor):
    assert input_scale.dim() == 2
    m, k_div_block_size = input_scale.shape
    padd_m = get_tma_aligned_size(m, input_scale.element_size())
    output = torch.empty(
        (k_div_block_size, padd_m), dtype=input_scale.dtype, device=input_scale.device
    )
    grid_m = min(m, 8192)
    BLOCK_SIZE_K = triton.next_power_of_2(k_div_block_size)

    _tma_align_input_scale_kernel[(grid_m,)](
        input_scale_ptr=input_scale,
        output_ptr=output,
        m=m,
        k_div_block_size=k_div_block_size,
        input_scale_stride_m=input_scale.stride(0),
        input_scale_stride_k=input_scale.stride(1),
        output_stride_m=output.stride(1),  # Note: these are swapped
        output_stride_k=output.stride(0),  # for column-major
        BLOCK_SIZE_K=BLOCK_SIZE_K,
    )
    return output.t()[:m]
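The return value relies on a stride trick: the kernel fills a row-major (k, padd_m) buffer, and `output.t()[:m]` exposes it as a logical (m, k) tensor whose columns are contiguous, i.e. the column-major, M-padded layout the TMA loads expect. A minimal CPU illustration of the view, with assumed small sizes:

# Stride illustration for the column-major trick (assumed sizes, CPU only).
import torch

m, k, padd_m = 6, 3, 8  # m = 6 padded to 8
buf = torch.zeros(k, padd_m)  # row-major (k, padd_m) storage
view = buf.t()[:m]  # logical (m, k) view over the same storage
print(view.shape, view.stride())  # torch.Size([6, 3]) (1, 8)
# Stride (1, padd_m): walking down a column touches contiguous memory.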
@@ -4,11 +4,19 @@ from typing import Callable, List, Optional, Tuple
import torch
from torch.nn import Module

from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM

try:
    from deep_gemm import (
        get_col_major_tma_aligned_tensor,
        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
        m_grouped_gemm_fp8_fp8_bf16_nt_masked,
    )
    from sgl_kernel import silu_and_mul

    from sglang.srt.layers.quantization.fp8_kernel import (
        sglang_per_token_group_quant_fp8,
    )

    use_deep_gemm = True
except ImportError:
@@ -20,6 +28,8 @@ from sglang.srt.distributed import (
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.moe.ep_moe.kernels import (
    ep_gather,
    ep_scatter,
    gelu_and_mul_triton_kernel,
    grouped_gemm_triton,
    post_reorder_triton_kernel,
@@ -27,6 +37,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
    run_moe_ep_preproess,
    silu_and_mul_masked_post_quant_fwd,
    silu_and_mul_triton_kernel,
    tma_align_input_scale,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
@@ -842,15 +853,23 @@ class DeepEPMoE(EPMoE):
    def forward(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        reorder_topk_ids: torch.Tensor,
        seg_indptr: torch.Tensor,
        masked_m: torch.Tensor,
        expected_m: int,
        num_recv_tokens_per_expert: List[int],
        forward_mode: ForwardMode,
    ):
        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
        if resolved_deepep_mode == DeepEPMode.normal:
            if _ENABLE_JIT_DEEPGEMM:
                return self.forward_deepgemm_contiguous(
                    hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
                )
            else:
                return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
        elif resolved_deepep_mode == DeepEPMode.low_latency:
            return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
        else:
@@ -969,6 +988,106 @@ class DeepEPMoE(EPMoE):
            )
            return down_output
    def forward_deepgemm_contiguous(
        self,
        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
        topk_idx,
        topk_weights,
        num_recv_tokens_per_expert: List[int],
    ):
        hidden_states_fp8, hidden_states_scale = hidden_states_fp8
        assert self.quant_method is not None
        assert self.activation == "silu"
        if num_recv_tokens_per_expert is None:
            return hidden_states_fp8.bfloat16()
        all_tokens = sum(num_recv_tokens_per_expert)
        if all_tokens <= 0:
            return hidden_states_fp8.bfloat16()
        M, K = hidden_states_fp8.size()
        N = self.w13_weight.size(1)
        scale_block_size = 128

        gather_out = torch.empty_like(
            hidden_states_fp8,
            device=hidden_states_fp8.device,
            dtype=torch.bfloat16,
        )

        # Packed per-expert layout: FP8 activations plus their per-128-group scales.
        input_tensor = [
            torch.empty(
                (all_tokens, K),
                device=hidden_states_fp8.device,
                dtype=hidden_states_fp8.dtype,
            ),
            torch.empty(
                (all_tokens, K // 128),
                device=hidden_states_fp8.device,
                dtype=torch.float32,
            ),
        ]
        m_indices = torch.empty(
            all_tokens, device=hidden_states_fp8.device, dtype=torch.int32
        )
        output_index = torch.empty_like(topk_idx)

        num_recv_tokens_per_expert_gpu = torch.tensor(
            num_recv_tokens_per_expert,
            dtype=torch.int32,
            pin_memory=True,
            device="cpu",
        ).cuda(non_blocking=True)
        expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

        ep_scatter(
            hidden_states_fp8,
            hidden_states_scale,
            topk_idx,
            num_recv_tokens_per_expert_gpu,
            expert_start_loc,
            input_tensor[0],
            input_tensor[1],
            m_indices,
            output_index,
        )

        gateup_output = torch.empty(
            (all_tokens, N),
            device=hidden_states_fp8.device,
            dtype=torch.bfloat16,
        )
        input_tensor[1] = tma_align_input_scale(input_tensor[1])
        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            input_tensor, self.w13_weight_fp8, gateup_output, m_indices
        )
        down_input = torch.empty(
            (
                all_tokens,
                N // 2,
            ),
            device=gateup_output.device,
            dtype=torch.bfloat16,
        )
        silu_and_mul(gateup_output.view(-1, N), down_input)
        down_output = torch.empty(
            (all_tokens, K),
            device=hidden_states_fp8.device,
            dtype=torch.bfloat16,
        )
        down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
            down_input, scale_block_size
        )
        down_input_scale = tma_align_input_scale(down_input_scale)
        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            (down_input_fp8, down_input_scale),
            self.w2_weight_fp8,
            down_output,
            m_indices,
        )
        ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)

        return gather_out
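Read end to end, the contiguous path above is: scatter FP8 tokens into per-expert segments, run two grouped GEMMs with an activation in between, and gather the weighted results back per token. The outline below uses the names from the code above; it is a comment-level summary, not a second implementation.

# Dataflow of forward_deepgemm_contiguous (outline of the code above):
#   (hidden_states_fp8, hidden_states_scale)      FP8 tokens from dispatch
#     -> ep_scatter(...)                          pack tokens per expert
#          input_tensor[0]: (all_tokens, K) fp8, input_tensor[1]: scales
#     -> tma_align_input_scale(input_tensor[1])   pad/transpose scales for TMA
#     -> m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(input_tensor, w13, ...)
#          gateup_output: (all_tokens, N) bf16
#     -> silu_and_mul(...)                        down_input: (all_tokens, N // 2)
#     -> sglang_per_token_group_quant_fp8(down_input, 128)   back to FP8
#     -> m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(..., w2, ...)
#          down_output: (all_tokens, K) bf16
#     -> ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
#          gather_out: (M, K) bf16, top-k weighted sum per original token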
    def forward_deepgemm_masked(
        self,
        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
...
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.utils import DeepEPMode

try:
    from deep_ep import Buffer

    from sglang.srt.layers.quantization.fp8_kernel import (
        sglang_per_token_group_quant_fp8,
    )

    use_deepep = True
except ImportError:
    use_deepep = False
from enum import IntEnum, auto
from typing import Optional, Tuple, Union

import torch
import torch.distributed as dist
@@ -78,7 +83,6 @@ class DeepEPBuffer:
            ),
            num_rdma_bytes,
        )
        cls._buffer = Buffer(
            group,
            num_nvl_bytes,
@@ -181,44 +185,74 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
        topk_weights: torch.Tensor,
    ):
        topk_idx = topk_idx.to(torch.int64)
        if _ENABLE_JIT_DEEPGEMM:
            # TODO: the 128 block-quant size is hard-coded; use FP8 communication.
            hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
        previous_event = Buffer.capture() if self.async_finish else None
        return hidden_states, topk_idx, topk_weights, previous_event

    def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
        if _ENABLE_JIT_DEEPGEMM:
            (
                hidden_states,
                topk_idx,
                topk_weights,
                num_recv_tokens_per_expert_list,
                event,
            ) = self._dispatch_core(
                hidden_states, topk_idx, topk_weights, previous_event
            )
            event.current_stream_wait() if self.async_finish else ()
            return (
                hidden_states,
                topk_idx,
                topk_weights,
                None,
                num_recv_tokens_per_expert_list,
                None,
                None,
                None,
            )
        else:
            (
                hidden_states,
                topk_idx,
                topk_weights,
                num_recv_tokens_per_expert_list,
                event,
            ) = self._dispatch_core(
                hidden_states, topk_idx, topk_weights, previous_event
            )
            event.current_stream_wait() if self.async_finish else ()
            if hidden_states.shape[0] > 0:
                reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
                    hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
                )
            else:
                reorder_topk_ids = torch.empty(
                    (0,), device=hidden_states.device, dtype=torch.int64
                )
                seg_indptr = torch.zeros(
                    (self.num_experts + 1,),
                    device=hidden_states.device,
                    dtype=torch.int64,
                )
            masked_m = expected_m = None
            return (
                hidden_states,
                topk_idx,
                topk_weights,
                reorder_topk_ids,
                None,
                seg_indptr,
                masked_m,
                expected_m,
            )
    def _dispatch_core(
        self,
        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        previous_event,
@@ -246,7 +280,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
            num_recv_tokens_per_expert_list,
            self.handle,
            event,
        ) = buffer.dispatch(
@@ -260,12 +294,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
            previous_event=previous_event,
            async_finish=self.async_finish,
            allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
            expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
        )

        return (
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
            num_recv_tokens_per_expert_list,
            event,
        )
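The new expert_alignment=128 argument asks DeepEP to pad each expert's received token count up to a multiple of 128, which is what lets ep_scatter assume m_indices.shape[0] % BLOCK_E == 0 and matches the per-group M granularity of the grouped-contiguous GEMM. The padding rule itself, as a standalone sketch:

# Illustration of expert_alignment=128 padding (standalone sketch).
def align_up(n: int, alignment: int = 128) -> int:
    return (n + alignment - 1) // alignment * alignment

print([align_up(n) for n in (0, 1, 127, 128, 300)])  # [0, 128, 128, 128, 384]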
@@ -314,29 +350,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
    ):
        if _ENABLE_JIT_DEEPGEMM:
            output = hidden_states
        else:
            if hidden_states.shape[0] > 0:
                num_tokens = self.src2dst.shape[0] // self.router_topk
                output = torch.empty(
                    (num_tokens, hidden_states.shape[1]),
                    device=hidden_states.device,
                    dtype=hidden_states.dtype,
                )
                deepep_post_reorder_triton_kernel[(num_tokens,)](
                    hidden_states,
                    output,
                    self.src2dst,
                    topk_idx,
                    topk_weights,
                    self.router_topk,
                    hidden_states.shape[1],
                    BLOCK_SIZE=512,
                )
            else:
                output = torch.zeros(
                    (0, hidden_states.shape[1]),
                    device=hidden_states.device,
                    dtype=hidden_states.dtype,
                )
        previous_event = Buffer.capture() if self.async_finish else None
        return output, previous_event
@@ -360,6 +399,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
    def _get_buffer(self):
        DeepEPBuffer.set_dispatch_mode_as_normal()
        return DeepEPBuffer.get_deepep_buffer(
            self.group,
            self.hidden_size,
@@ -426,6 +466,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
            topk_idx,
            topk_weights,
            reorder_topk_ids,
            None,
            seg_indptr,
            masked_m,
            expected_m,
@@ -570,7 +611,8 @@ class DeepEPDispatcher:
    def dispatch(self, *args, **kwargs) -> Tuple:
        self.dispatch_a(*args, **kwargs)
        ret = self.dispatch_b()
        return ret

    def dispatch_a(
        self,
@@ -593,7 +635,8 @@ class DeepEPDispatcher:
    def combine(self, *args, **kwargs) -> Tuple:
        self.combine_a(*args, **kwargs)
        ret = self.combine_b()
        return ret

    def combine_a(
        self,
...
@@ -28,6 +28,11 @@ if is_cuda():
    if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
        _ENABLE_JIT_DEEPGEMM = True


def get_enable_jit_deepgemm():
    return _ENABLE_JIT_DEEPGEMM


logger = logging.getLogger(__name__)

_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
...
@@ -308,8 +308,8 @@ def sglang_per_token_group_quant_fp8(
        device=x.device,
        dtype=torch.float32,
    )
    if x.shape[0] > 0:
        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)

    return x_q, x_s
...
@@ -357,6 +357,7 @@ class DeepseekV2MoE(nn.Module):
            topk_idx,
            topk_weights,
            reorder_topk_ids,
            num_recv_tokens_per_expert,
            seg_indptr,
            masked_m,
            expected_m,
@@ -368,10 +369,13 @@ class DeepseekV2MoE(nn.Module):
        )
        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            reorder_topk_ids=reorder_topk_ids,
            seg_indptr=seg_indptr,
            masked_m=masked_m,
            expected_m=expected_m,
            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
            forward_mode=forward_mode,
        )
        if self.ep_size > 1:
...