Unverified Commit 23c764b1 authored by Jinyan Chen, committed by GitHub

[Feature] Support DeepEP Low Latency (#4767)


Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: laixinn <xielx@shanghaitech.edu.cn>
Co-authored-by: ch-wan <cwan39@gatech.edu>
parent 87fafa01
@@ -91,6 +91,7 @@ Please consult the documentation below to learn more about the parameters you ma
* `enable_ep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for MoE models.
* `ep_size`: The size of EP. Please shard the model weights with `tp_size=ep_size`; for detailed benchmarking refer to [this PR](https://github.com/sgl-project/sglang/pull/2203). If not set, `ep_size` will be automatically set to `tp_size`.
* `enable_deepep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for the DeepSeek-V3 model, based on deepseek-ai/DeepEP.
* `deepep_mode`: Selects the mode when DeepEP MoE is enabled; one of `normal`, `low_latency`, or `auto`. The default is `auto`, which uses `low_latency` for decode batches and `normal` for prefill batches (see the sketch below).
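For clarity, a minimal illustrative sketch (an editor's addition, not part of the docs or the sglang code) of how `auto` resolves per batch; `is_decode` stands in for `ForwardMode.is_decode()`:

def resolve_deepep_path(deepep_mode: str, is_decode: bool) -> str:
    # `auto` picks the low-latency path only for decode batches.
    if deepep_mode == "normal" or (deepep_mode == "auto" and not is_decode):
        return "normal"
    elif deepep_mode == "low_latency" or (deepep_mode == "auto" and is_decode):
        return "low_latency"
    raise ValueError(f"Invalid deepep_mode: {deepep_mode}")

assert resolve_deepep_path("auto", is_decode=False) == "normal"
assert resolve_deepep_path("auto", is_decode=True) == "low_latency"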
## Memory and scheduling
......
@@ -244,6 +244,148 @@ def silu_and_mul_triton_kernel(
    tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
@triton.jit
def _silu_and_mul_post_quant_kernel(
input_ptr,
stride_input_0,
stride_input_1,
stride_input_2,
output_ptr,
stride_output_0,
stride_output_1,
stride_output_2,
output_scale_ptr,
stride_output_scale_0,
stride_output_scale_1,
stride_output_scale_2,
masked_m_ptr,
size_n,
fp8_max,
fp8_min,
BLOCK_N: tl.constexpr,
NUM_STAGE: tl.constexpr,
):
expert_id = tl.program_id(2)
token_id = tl.program_id(1)
hidden_dim_block_index = tl.program_id(0)
block_num_per_expert = tl.num_programs(1)
token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)
offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N)
input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d
output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d
output_scale_offs = (
output_scale_ptr
+ expert_id * stride_output_scale_0
+ hidden_dim_block_index * stride_output_scale_2
)
for token_index in tl.range(
token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE
):
gate = tl.load(
input_ptr_offs + token_index * stride_input_1,
mask=offs_in_d < size_n,
other=0.0,
).to(tl.float32)
up = tl.load(
input_ptr_offs + token_index * stride_input_1 + size_n,
mask=offs_in_d < size_n,
other=0.0,
)
gate = gate / (1 + tl.exp(-gate))
gate = gate.to(input_ptr.dtype.element_ty)
gate_up = up * gate
_absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
output_s = _absmax / fp8_max
output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
output_ptr.dtype.element_ty
)
tl.store(
output_ptr_offs + token_index * stride_output_1,
output_q,
mask=offs_in_d < size_n,
)
tl.store(
output_scale_offs + token_index * stride_output_scale_1,
output_s,
)
def silu_and_mul_masked_post_quant_fwd(
input: torch.Tensor,
output: torch.Tensor,
output_scale: torch.Tensor,
quant_group_size: int,
masked_m: torch.Tensor,
):
"""
input shape [expert_num, token_num_padded, hidden_dim]
output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8
output_scale [expert_num token_num_paddded, hidden_dim // 2 // 128] dtype float32
quant_group_size int,
masked_m shape [expert_num],
"""
assert input.is_contiguous()
assert output.dtype == torch.float8_e4m3fn
assert output.is_contiguous()
assert len(input.shape) == 3
assert input.shape[0] == masked_m.shape[0]
assert input.shape[-1] % 2 == 0
size_n = input.shape[-1] // 2
assert size_n % quant_group_size == 0
expert_num = len(masked_m)
if expert_num < 4:
BLOCK_NUM_PER_EXPERT = 64
else:
BLOCK_NUM_PER_EXPERT = 32
BLOCK_N = quant_group_size
num_warps = 1
NUM_STAGES = 6
hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N)
assert BLOCK_N % quant_group_size == 0
grid = (
hidden_dim_split_block_num,
BLOCK_NUM_PER_EXPERT,
expert_num,
)
finfo = torch.finfo(torch.float8_e4m3fn)
fp8_max = finfo.max
fp8_min = -fp8_max
_silu_and_mul_post_quant_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
output_scale,
*output_scale.stride(),
masked_m,
size_n,
fp8_max,
fp8_min,
BLOCK_N=BLOCK_N,
NUM_STAGE=NUM_STAGES,
num_warps=num_warps,
)
return
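# Editor's note: a hypothetical call (shapes are illustrative, not from this PR)
# matching the docstring above; quant_group_size is the 128-wide FP8 scale group.
#
#   E, T, H = 8, 128, 4096  # experts, padded tokens per expert, gate+up hidden dim
#   x = torch.randn(E, T, H, device="cuda", dtype=torch.bfloat16)
#   out = torch.empty(E, T, H // 2, device="cuda", dtype=torch.float8_e4m3fn)
#   out_scale = torch.empty(E, T, H // 2 // 128, device="cuda", dtype=torch.float32)
#   masked_m = torch.randint(0, T, (E,), device="cuda", dtype=torch.int64)
#   silu_and_mul_masked_post_quant_fwd(x, out, out_scale, 128, masked_m)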
@triton.jit
def tanh(x):
    return 2 * tl.sigmoid(2 * x) - 1
......
@@ -3,12 +3,16 @@ from typing import Callable, List, Optional, Tuple

import torch

-# TODO: use deep_gemm masked kernel after low latency dispatch
-# import deep_gemm
-# from deep_gemm import (
-#     get_col_major_tma_aligned_tensor,
-#     m_grouped_gemm_fp8_fp8_bf16_nt_masked,
-# )
+try:
+    from deep_gemm import (
+        get_col_major_tma_aligned_tensor,
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked,
+    )
+
+    use_deep_gemm = True
+except ImportError:
+    use_deep_gemm = False

from torch.nn import Module

from sglang.srt.custom_op import CustomOp
@@ -22,6 +26,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
    post_reorder_triton_kernel,
    pre_reorder_triton_kernel,
    run_moe_ep_preproess,
+    silu_and_mul_masked_post_quant_fwd,
    silu_and_mul_triton_kernel,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
@@ -809,6 +814,7 @@ class DeepEPMoE(EPMoE):
        correction_bias: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        activation: str = "silu",
+        deepep_mode: str = "auto",
    ):
        super().__init__(
            num_experts,
@@ -827,21 +833,41 @@ class DeepEPMoE(EPMoE):
            custom_routing_function,
            activation,
        )
self.deepep_mode = deepep_mode
if self.deepep_mode in ["low_latency", "auto"]:
assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)
    def forward(
        self,
        hidden_states: torch.Tensor,
        reorder_topk_ids: torch.Tensor,
        seg_indptr: torch.Tensor,
+        masked_m: torch.Tensor,
+        expected_m: int,
        forward_mode: ForwardMode,
    ):
-        # Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
-        if True:  # not forward_mode.is_decode():
+        if self.deepep_mode == "normal" or (
+            self.deepep_mode == "auto" and not forward_mode.is_decode()
+        ):
            return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
+        elif self.deepep_mode == "low_latency" or (
+            self.deepep_mode == "auto" and forward_mode.is_decode()
+        ):
+            return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
        else:
-            return self.forward_deepgemm_masked(
-                hidden_states, reorder_topk_ids, seg_indptr
-            )
+            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
    def forward_normal(
        self,
@@ -958,89 +984,66 @@ class DeepEPMoE(EPMoE):
    def forward_deepgemm_masked(
        self,
-        hidden_states: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
+        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
+        masked_m: torch.Tensor,
+        expected_m: int,
    ):
        assert self.quant_method is not None
        assert self.activation == "silu"
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            max_value = (
-                torch.max(hidden_states)
-                .repeat(self.num_experts_per_partition)
-                .to(torch.float32)
-            )
-            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+        assert (
+            hidden_states_fp8[0].size(0) % 4 == 0
+        ), f"TMA alignment error: {hidden_states_fp8[0].size(0)}"

        # GroupGemm-0
+        num_groups, m, k = hidden_states_fp8[0].size()
+        n = self.w13_weight.size(1)
+        expected_m = min(expected_m, m)
        gateup_output = torch.empty(
-            hidden_states.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
        )
-        if hidden_states.shape[0] > 0:
-            # Transpose earlier so that the testing will not trigger transposing kernels
-            hidden_states = (
-                hidden_states[0],
-                get_col_major_tma_aligned_tensor(hidden_states[1]),
-            )
-            """
-            gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-                hidden_states, self.w13_weight, out, masked_m, expected_m
-            )
-            """
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+            hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+        )

        # Act
        down_input = torch.empty(
-            gateup_output.shape[0],
-            gateup_output.shape[1] // 2,
-            device=gateup_output.device,
-            dtype=(
-                self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
-            ),
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2,
+            ),
+            device=gateup_output.device,
+            dtype=self.fp8_dtype,
        )
-        if self.w2_input_scale is None and not self.use_block_quant:
-            self.w2_input_scale = torch.ones(
-                self.num_experts_per_partition,
-                dtype=torch.float32,
-                device=hidden_states.device,
-            )
-        if self.activation == "silu":
-            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                0,
-                self.num_experts_per_partition - 1,
-                BLOCK_SIZE=512,
-            )
-        else:
-            raise ValueError(f"Unsupported activation: {self.activation=}")
+        scale_block_size = 128
+        down_input_scale = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2 // scale_block_size,
+            ),
+            device=gateup_output.device,
+            dtype=torch.float32,
+        )
+        silu_and_mul_masked_post_quant_fwd(
+            gateup_output,
+            down_input,
+            down_input_scale,
+            scale_block_size,
+            masked_m,
+        )

        # GroupGemm-1
+        n = self.w2_weight.size(1)
+        down_input_fp8 = (
+            down_input,
+            get_col_major_tma_aligned_tensor(down_input_scale),
+        )
        down_output = torch.empty(
-            down_input.shape[0],
-            self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
        )
-        if down_input.shape[0] > 0:
-            # Transpose earlier so that the testing will not trigger transposing kernels
-            down_input = (
-                down_input[0],
-                get_col_major_tma_aligned_tensor(down_input[1]),
-            )
-            """
-            down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-                down_input, self.w2_weight, out, masked_m, expected_m
-            )
-            """
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+        )
        return down_output
@@ -76,8 +76,7 @@ def get_buffer_low_latency(
    assert num_experts % group.size() == 0
    _buffer_low_latency = Buffer(
        group,
-        0,
-        num_rdma_bytes,
+        num_rdma_bytes=num_rdma_bytes,
        low_latency_mode=True,
        num_qps_per_rank=num_experts // group.size(),
    )
@@ -95,62 +94,63 @@ class DeepEPDispatcher:
        group: torch.distributed.ProcessGroup,
        router_topk: int,
        permute_fusion: bool = False,
-        capacity_factor: float = None,
        num_experts: int = None,
        num_local_experts: int = None,
        hidden_size: int = None,
        params_dtype: torch.dtype = None,
+        deepep_mode: str = "auto",
        async_finish: bool = False,
+        return_recv_hook: bool = False,
    ):
+        if not use_deepep:
+            raise ImportError(
+                "DeepEP is not installed. Please install DeepEP package from "
+                "https://github.com/deepseek-ai/deepep."
+            )
+
        self.group = group
        self.router_topk = router_topk
-        self.capacity_factor = capacity_factor
        self.permute_fusion = permute_fusion
        self.num_experts = num_experts
        self.num_local_experts = num_local_experts
        self.hidden_size = hidden_size
-        self.recv_expert_count = None
        self.params_dtype = params_dtype
        self.params_bytes = 2
-        # Metadata
-        self.token_indices = None
-        self.token_probs = None
-        # Handle used for combine operation
-        self.handle = None
-        self.async_finish = async_finish
-        # `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
-        # https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
-        self.num_max_dispatch_tokens_per_rank = 128
-        if not use_deepep:
-            raise ImportError(
-                "DeepEP is not installed. Please install DeepEP package from "
-                "https://github.com/deepseek-ai/deepep."
-            )
-        self.buffer_normal = get_buffer_normal(
-            self.group, self.hidden_size * self.params_bytes
-        )
-        self.buffer_low_latency = None
-        # Todo: enable low latency dispatch
-        """
-        self.buffer_low_latency = get_buffer_low_latency(
-            self.group,
-            self.num_max_dispatch_tokens_per_rank,
-            self.hidden_size * self.params_bytes,
-            self.num_experts,
-        )
-        """
+        self.deepep_mode = deepep_mode
+        self.handle = None
+
+        if self.deepep_mode in ["normal", "auto"]:  # for normal / auto mode
+            self.buffer_normal = get_buffer_normal(
+                self.group, self.hidden_size * self.params_bytes
+            )
+            self.async_finish = async_finish
+            self.src2dst = None
+        if self.deepep_mode in ["low_latency", "auto"]:  # for low_latency / auto mode
+            """
+            num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
+            https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
+            """
+            # TODO(ch-wan): allow users to set this value
+            self.num_max_dispatch_tokens_per_rank = 128
+            self.buffer_low_latency = get_buffer_low_latency(
+                self.group,
+                self.num_max_dispatch_tokens_per_rank,
+                self.hidden_size,
+                self.num_experts,
+            )
+            self.return_recv_hook = return_recv_hook
    def deepep_permute(
        self,
-        hidden_states,
-        fp8_dtype=None,
-        use_fp8_w8a8=False,
-        use_block_quant=False,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        fp8_dtype: Optional[torch.dtype] = None,
+        use_fp8_w8a8: bool = False,
+        use_block_quant: bool = False,
    ):
-        reorder_topk_ids, src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
-            self.topk_idx, self.num_experts
+        reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
+            topk_idx, self.num_experts
        )
        num_total_tokens = reorder_topk_ids.numel()
        gateup_input = torch.empty(
@@ -166,14 +166,13 @@ class DeepEPDispatcher:
        deepep_permute_triton_kernel[(hidden_states.shape[0],)](
            hidden_states,
            gateup_input,
-            src2dst,
-            self.topk_idx,
+            self.src2dst,
+            topk_idx,
            None,
            self.router_topk,
            hidden_states.shape[1],
            BLOCK_SIZE=512,
        )
-        self.src2dst = src2dst
        return reorder_topk_ids, seg_indptr, gateup_input
    def dispatch(
        self,
@@ -182,54 +181,64 @@ class DeepEPDispatcher:
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        num_experts: int,
-        forward_mode: ForwardMode,
        num_max_dispatch_tokens_per_rank: int = 128,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        forward_mode: ForwardMode = None,
+    ) -> Tuple:
        topk_idx = topk_idx.to(torch.int64)
-        # Todo: enable low latency dispatch
-        if True:  # not forward_mode.is_decode():
+        reorder_topk_ids = torch.empty(
+            (0,), device=hidden_states.device, dtype=torch.int64
+        )
+        seg_indptr = torch.zeros(
+            (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+        )
+        masked_m = torch.empty(
+            (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
+        )
+        expected_m = 0
+        if self.deepep_mode == "normal" or (
+            self.deepep_mode == "auto" and not forward_mode.is_decode()
+        ):
            (
                hidden_states,
                topk_idx,
                topk_weights,
-                num_recv_tokens_per_expert_list,
-                handle,
                event,
            ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
-            self.tokens_per_expert = torch.tensor(
-                num_recv_tokens_per_expert_list,
-                device=hidden_states.device,
-                dtype=torch.int64,
-            )
-        else:
-            hidden_states, recv_expert_count, handle, event, hook = (
-                self.dispatch_low_latency(
-                    hidden_states,
-                    topk_idx,
-                    num_max_dispatch_tokens_per_rank,
-                    num_experts,
-                )
-            )
-            self.recv_expert_count = recv_expert_count
-        if self.async_finish:
-            event.current_stream_wait()
-        self.handle = handle
-        self.topk_idx = topk_idx
-        self.topk_weights = topk_weights
-        if hidden_states.shape[0] > 0:
-            reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
-                hidden_states, fp8_dtype=hidden_states.dtype
-            )
-        else:
-            reorder_topk_ids = torch.empty(
-                (0,), device=hidden_states.device, dtype=torch.int64
-            )
-            seg_indptr = torch.zeros(
-                (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
-            )
-        return hidden_states, reorder_topk_ids, seg_indptr
+            event.current_stream_wait() if self.async_finish else ()
+            if hidden_states.shape[0] > 0:
+                reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
+                    hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
+                )
+        elif self.deepep_mode == "low_latency" or (
+            self.deepep_mode == "auto" and forward_mode.is_decode()
+        ):
+            expected_m = (
+                hidden_states.shape[0]
+                * self.buffer_low_latency.group_size
+                * topk_idx.shape[1]
+                + num_experts
+            ) // num_experts
+            hidden_states, masked_m, event, hook = self.dispatch_low_latency(
+                hidden_states,
+                topk_idx,
+                num_max_dispatch_tokens_per_rank,
+                num_experts,
+                use_fp8=True,
+            )
+            hook() if self.return_recv_hook else event.current_stream_wait()
+        else:
+            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")

+        return (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            reorder_topk_ids,
+            seg_indptr,
+            masked_m,
+            expected_m,
+        )
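# Editor's note (illustrative arithmetic, not part of the file): `expected_m`
# above is a ceiling-style estimate of tokens per expert. With 8 local tokens,
# a dispatch group of 8 ranks, top-8 routing and 256 experts:
#   (8 * 8 * 8 + 256) // 256 == 3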
    def dispatch_normal(
        self,
@@ -254,12 +263,15 @@ class DeepEPDispatcher:
            allocate_on_comm_stream=previous_event is not None,
        )

+        # FIXME: `handle` should be transmitted with tokens from dispatch to combine.
+        # However, doing this would incur an unknown synchronization error, but keeping
+        # `handle` as a member variable works.
        (
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
-            num_recv_tokens_per_expert_list,
-            handle,
+            _,  # num_recv_tokens_per_expert_list
+            self.handle,
            event,
        ) = self.buffer_normal.dispatch(
            x,
@@ -278,8 +290,6 @@ class DeepEPDispatcher:
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
-            num_recv_tokens_per_expert_list,
-            handle,
            event,
        )
@@ -289,18 +299,19 @@ class DeepEPDispatcher:
        topk_idx: torch.Tensor,
        num_max_dispatch_tokens_per_rank: int,
        num_experts: int,
+        use_fp8: bool = False,
    ):
        """
-        # For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'
-        # Please please make sure to change DeepEP code in internode_ll.cu dispatch / combine first and then reinstall!
+        # For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'.
+        # Please make sure to change DeepEP code in internode_ll.cu dispatch / combine as below first and then reinstall.
        # More details refer: https://github.com/deepseek-ai/DeepEP/issues/15#issuecomment-2709715782
        +
        diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
-        index f60e933..cddaabf 100644
+        index 76ae2e2..8ecd08f 100644
        --- a/csrc/kernels/internode_ll.cu
        +++ b/csrc/kernels/internode_ll.cu
-        @@ -307,14 +307,14 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
-                         int num_topk, int num_experts, int rank, int num_ranks,
+        @@ -310,8 +310,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
+                         int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
                         void* workspace, cudaStream_t stream, int phases) {
            constexpr int kNumMaxTopK = 9;
        -    constexpr int kNumWarpsPerGroup = 10;
@@ -308,16 +319,9 @@ class DeepEPDispatcher:
        +    constexpr int kNumWarpsPerGroup = 8;
        +    constexpr int kNumWarpGroups = 4;
            EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
-        +
            const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
-            const auto num_sms = cell_div(num_experts, kNumWarpGroups);
-            EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
-        -    EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
-        +    // EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
-        +
-            // Workspace checks
-            auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
-        @@ -505,8 +505,8 @@ void combine(void* combined_x,
+        @@ -501,8 +501,8 @@ void combine(void* combined_x,
                         int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
                         int num_topk, int num_experts, int rank, int num_ranks,
                         void* workspace, cudaStream_t stream, int phases) {
@@ -326,28 +330,33 @@ class DeepEPDispatcher:
        +    constexpr int kNumWarpsPerGroup = 8;
        +    constexpr int kNumWarpGroups = 4;
            constexpr int kNumMaxTopk = 9;
        +
            const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
        """
-        recv_hidden_states, recv_expert_count, handle, event, hook = (
+        packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
            self.buffer_low_latency.low_latency_dispatch(
                hidden_states,
                topk_idx,
                num_max_dispatch_tokens_per_rank,
                num_experts,
-                async_finish=self.async_finish,
-                return_recv_hook=False,  # True for double-batch overlapping, need call hook()
+                use_fp8=use_fp8,
+                async_finish=not self.return_recv_hook,
+                return_recv_hook=self.return_recv_hook,
            )
        )
-        # hook()
-        return recv_hidden_states, recv_expert_count, handle, event, hook
+        return packed_recv_hidden, packed_recv_count, event, hook
    def combine(
-        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        # Todo: enable low latency combine
-        if True:  # not forward_mode.is_decode():
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_mode: ForwardMode,
+    ) -> torch.Tensor:
+        if self.deepep_mode == "normal" or (
+            self.deepep_mode == "auto" and not forward_mode.is_decode()
+        ):
            if hidden_states.shape[0] > 0:
                num_tokens = self.src2dst.shape[0] // self.router_topk
                output = torch.empty(
@@ -359,8 +368,8 @@ class DeepEPDispatcher:
                    hidden_states,
                    output,
                    self.src2dst,
-                    self.topk_idx,
-                    self.topk_weights,
+                    topk_idx,
+                    topk_weights,
                    self.router_topk,
                    hidden_states.shape[1],
                    BLOCK_SIZE=512,
@@ -371,24 +380,30 @@ class DeepEPDispatcher:
                    device=hidden_states.device,
                    dtype=hidden_states.dtype,
                )
-            hidden_states, event = self.combine_normal(output, self.handle)
-        else:
+            hidden_states, event = self.combine_normal(
+                output,
+            )
+            event.current_stream_wait() if self.async_finish else ()
+        elif self.deepep_mode == "low_latency" or (
+            self.deepep_mode == "auto" and forward_mode.is_decode()
+        ):
            hidden_states, event, hook = self.combine_low_latency(
-                hidden_states, self.topk_idx, self.topk_weights, self.handle
+                hidden_states,
+                topk_idx,
+                topk_weights,
            )
+            hook() if self.return_recv_hook else event.current_stream_wait()
+        else:
+            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")

-        if self.async_finish:
-            event.current_stream_wait()
-        self.handle = None
        return hidden_states

-    def combine_normal(self, x: torch.Tensor, handle: Tuple):
+    def combine_normal(self, x: torch.Tensor):
        previous_event = Buffer.capture() if self.async_finish else None
        combined_x, _, event = self.buffer_normal.combine(
            x,
-            handle,
+            self.handle,
            async_finish=self.async_finish,
            previous_event=previous_event,
            allocate_on_comm_stream=previous_event is not None,
@@ -400,17 +415,15 @@ class DeepEPDispatcher:
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
-        handle: Tuple,
    ):
-        combined_hidden_states, event_overlap, hook = (
+        combined_hidden_states, event, hook = (
            self.buffer_low_latency.low_latency_combine(
                hidden_states,
                topk_idx,
                topk_weights,
-                handle,
-                async_finish=self.async_finish,
-                return_recv_hook=False,  # True for double-batch overlapping, need call hook()
+                self.handle,
+                async_finish=not self.return_recv_hook,
+                return_recv_hook=self.return_recv_hook,
            )
        )
-        # hook()
-        return combined_hidden_states, event_overlap, hook
+        return combined_hidden_states, event, hook
@@ -72,6 +72,7 @@ global_server_args_dict = {
    "enable_dp_attention": ServerArgs.enable_dp_attention,
    "enable_ep_moe": ServerArgs.enable_ep_moe,
    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
+    "deepep_mode": ServerArgs.deepep_mode,
    "device": ServerArgs.device,
    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
......
@@ -147,6 +147,7 @@ class ModelRunner:
            "enable_dp_attention": server_args.enable_dp_attention,
            "enable_ep_moe": server_args.enable_ep_moe,
            "enable_deepep_moe": server_args.enable_deepep_moe,
+            "deepep_mode": server_args.deepep_mode,
            "device": server_args.device,
            "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
            "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
@@ -272,7 +273,7 @@ class ModelRunner:
            server_args.disable_radix_cache = True

        if server_args.enable_deepep_moe:
-            logger.info("DeepEP is turned on.")
+            logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")

    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")
......
@@ -188,19 +188,35 @@ class DeepseekV2MoE(nn.Module):
            if global_server_args_dict["enable_deepep_moe"]
            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
        )
-        self.experts = MoEImpl(
-            num_experts=config.n_routed_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
-            prefix=add_prefix("experts", prefix),
-        )
+        if not global_server_args_dict["enable_deepep_moe"]:
+            self.experts = MoEImpl(
+                num_experts=config.n_routed_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                renormalize=config.norm_topk_prob,
+                quant_config=quant_config,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                topk_group=config.topk_group,
+                correction_bias=self.gate.e_score_correction_bias,
+                prefix=add_prefix("experts", prefix),
+            )
+        else:
+            self.experts = MoEImpl(
+                num_experts=config.n_routed_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                renormalize=config.norm_topk_prob,
+                quant_config=quant_config,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                topk_group=config.topk_group,
+                correction_bias=self.gate.e_score_correction_bias,
+                prefix=add_prefix("experts", prefix),
+                deepep_mode=global_server_args_dict["deepep_mode"],
+            )

        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
@@ -227,6 +243,8 @@ class DeepseekV2MoE(nn.Module):
        )

        if global_server_args_dict["enable_deepep_moe"]:
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_tensor_model_parallel_world_size()
            self.num_experts = config.n_routed_experts
            self.top_k = config.num_experts_per_tok
            self.renormalize = config.norm_topk_prob
@@ -246,7 +264,9 @@ class DeepseekV2MoE(nn.Module):
                num_local_experts=config.n_routed_experts // self.tp_size,
                hidden_size=config.hidden_size,
                params_dtype=config.torch_dtype,
+                deepep_mode=global_server_args_dict["deepep_mode"],
                async_finish=True,  # TODO
+                return_recv_hook=True,
            )

    def forward(
@@ -301,28 +321,39 @@ class DeepseekV2MoE(nn.Module):
            num_expert_group=self.num_expert_group,
            correction_bias=self.correction_bias,
        )
-        if self.tp_size > 1:
-            recv_hidden_states, reorder_topk_ids, seg_indptr = (
-                self.deepep_dispatcher.dispatch(
-                    hidden_states,
-                    topk_idx,
-                    topk_weights,
-                    self.num_experts,
-                    forward_mode,
-                )
-            )
+        if self.ep_size > 1:
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                reorder_topk_ids,
+                seg_indptr,
+                masked_m,
+                expected_m,
+            ) = self.deepep_dispatcher.dispatch(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                self.num_experts,
+                forward_mode=forward_mode,
+            )
        final_hidden_states = (
            self.experts(
-                hidden_states=recv_hidden_states,
+                hidden_states=hidden_states,
                reorder_topk_ids=reorder_topk_ids,
                seg_indptr=seg_indptr,
+                masked_m=masked_m,
+                expected_m=expected_m,
                forward_mode=forward_mode,
            )
            * self.routed_scaling_factor
        )
-        if self.tp_size > 1:
+        if self.ep_size > 1:
            final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states, forward_mode
+                final_hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode,
            )
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
......
@@ -161,6 +161,7 @@ class ServerArgs:
    enable_dp_attention: bool = False
    enable_ep_moe: bool = False
    enable_deepep_moe: bool = False
+    deepep_mode: Optional[str] = "auto"
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    cuda_graph_max_bs: Optional[int] = None
@@ -285,6 +286,13 @@ class ServerArgs:
        if self.grammar_backend is None:
            self.grammar_backend = "xgrammar"

+        # Expert parallelism
+        if self.enable_ep_moe:
+            self.ep_size = self.tp_size
+            logger.info(
+                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
+            )
+
        # Data parallelism attention
        if self.enable_dp_attention:
            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
@@ -300,6 +308,10 @@ class ServerArgs:
            self.enable_sp_layernorm = False
        # DeepEP MoE
        if self.enable_deepep_moe:
+            if self.deepep_mode == "auto":
+                assert (
+                    not self.enable_dp_attention
+                ), "DeepEP MoE `auto` mode is not supported with DP Attention."
            self.ep_size = self.tp_size
            self.enable_sp_layernorm = (
                self.dp_size < self.tp_size if self.enable_dp_attention else True
@@ -1082,6 +1094,12 @@ class ServerArgs:
            action="store_true",
            help="Enabling DeepEP MoE implementation for EP MoE.",
        )
+        parser.add_argument(
+            "--deepep-mode",
+            type=str,
+            choices=["normal", "low_latency", "auto"],
+            help="Select the mode when enabling DeepEP MoE; can be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batches and `normal` for prefill batches.",
+        )

        # Server warmups
        parser.add_argument(
......
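To summarize the new server-argument behavior, here is a minimal standalone sketch (an editor's illustration, not the sglang implementation): enabling EP or DeepEP MoE forces `ep_size` to `tp_size`, and `auto` mode is rejected when DP attention is enabled.

def check_deepep_args(tp_size: int, enable_deepep_moe: bool, deepep_mode: str,
                      enable_dp_attention: bool) -> int:
    if enable_deepep_moe and deepep_mode == "auto":
        assert (
            not enable_dp_attention
        ), "DeepEP MoE `auto` mode is not supported with DP Attention."
    return tp_size  # ep_size is always aligned with tp_size

assert check_deepep_args(8, True, "low_latency", True) == 8
assert check_deepep_args(8, True, "auto", False) == 8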