Unverified Commit 23c764b1 authored by Jinyan Chen, committed by GitHub

[Feature] Support DeepEP Low Latency (#4767)


Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: laixinn <xielx@shanghaitech.edu.cn>
Co-authored-by: ch-wan <cwan39@gatech.edu>
parent 87fafa01
......@@ -91,6 +91,7 @@ Please consult the documentation below to learn more about the parameters you ma
* `enable_ep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for MoE models.
* `ep_size`: The size of EP. Please shard the model weights with `tp_size=ep_size`; for detailed benchmarking, refer to [this PR](https://github.com/sgl-project/sglang/pull/2203). If not set, `ep_size` will be automatically set to `tp_size`.
* `enable_deepep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for DeepSeek-V3 model based on deepseek-ai/DeepEP.
* `deepep_mode`: Selects the mode when DeepEP MoE is enabled; it can be `normal`, `low_latency`, or `auto`. The default is `auto`, which uses `low_latency` for decode batches and `normal` for prefill batches (see the sketch below).
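For illustration, a minimal sketch (not part of this PR) of enabling DeepEP MoE programmatically through `ServerArgs`; the `model_path` and `tp_size` values are placeholders:

```python
from sglang.srt.server_args import ServerArgs

# Hypothetical configuration; only the last two flags are specific to DeepEP MoE.
server_args = ServerArgs(
    model_path="deepseek-ai/DeepSeek-V3",  # placeholder model path
    tp_size=8,  # DeepEP MoE reuses tp_size as the expert parallel size
    enable_deepep_moe=True,
    deepep_mode="auto",  # "low_latency" for decode, "normal" for prefill
)
```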
## Memory and scheduling
......
......@@ -244,6 +244,148 @@ def silu_and_mul_triton_kernel(
tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
# Copied from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
@triton.jit
def _silu_and_mul_post_quant_kernel(
input_ptr,
stride_input_0,
stride_input_1,
stride_input_2,
output_ptr,
stride_output_0,
stride_output_1,
stride_output_2,
output_scale_ptr,
stride_output_scale_0,
stride_output_scale_1,
stride_output_scale_2,
masked_m_ptr,
size_n,
fp8_max,
fp8_min,
BLOCK_N: tl.constexpr,
NUM_STAGE: tl.constexpr,
):
expert_id = tl.program_id(2)
token_id = tl.program_id(1)
hidden_dim_block_index = tl.program_id(0)
block_num_per_expert = tl.num_programs(1)
token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
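# Grid layout: axis 0 tiles the hidden dimension into BLOCK_N-wide groups,
# axis 1 strides over this expert's valid tokens, and axis 2 selects the expert.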
stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)
offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N)
input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d
output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d
output_scale_offs = (
output_scale_ptr
+ expert_id * stride_output_scale_0
+ hidden_dim_block_index * stride_output_scale_2
)
for token_index in tl.range(
token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE
):
gate = tl.load(
input_ptr_offs + token_index * stride_input_1,
mask=offs_in_d < size_n,
other=0.0,
).to(tl.float32)
up = tl.load(
input_ptr_offs + token_index * stride_input_1 + size_n,
mask=offs_in_d < size_n,
other=0.0,
)
gate = gate / (1 + tl.exp(-gate))
gate = gate.to(input_ptr.dtype.element_ty)
gate_up = up * gate
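# Dynamic per-group FP8 quantization: one scale per BLOCK_N-element group,
# derived from the group's absolute maximum and clamped to the FP8 range.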
_absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
output_s = _absmax / fp8_max
output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
output_ptr.dtype.element_ty
)
tl.store(
output_ptr_offs + token_index * stride_output_1,
output_q,
mask=offs_in_d < size_n,
)
tl.store(
output_scale_offs + token_index * stride_output_scale_1,
output_s,
)
def silu_and_mul_masked_post_quant_fwd(
input: torch.Tensor,
output: torch.Tensor,
output_scale: torch.Tensor,
quant_group_size: int,
masked_m: torch.Tensor,
):
"""
input shape [expert_num, token_num_padded, hidden_dim]
output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8
output_scale shape [expert_num, token_num_padded, hidden_dim // 2 // quant_group_size], dtype float32
quant_group_size int,
masked_m shape [expert_num],
"""
assert input.is_contiguous()
assert output.dtype == torch.float8_e4m3fn
assert output.is_contiguous()
assert len(input.shape) == 3
assert input.shape[0] == masked_m.shape[0]
assert input.shape[-1] % 2 == 0
size_n = input.shape[-1] // 2
assert size_n % quant_group_size == 0
expert_num = len(masked_m)
if expert_num < 4:
BLOCK_NUM_PER_EXPERT = 64
else:
BLOCK_NUM_PER_EXPERT = 32
BLOCK_N = quant_group_size
num_warps = 1
NUM_STAGES = 6
hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N)
assert BLOCK_N % quant_group_size == 0
grid = (
hidden_dim_split_block_num,
BLOCK_NUM_PER_EXPERT,
expert_num,
)
finfo = torch.finfo(torch.float8_e4m3fn)
fp8_max = finfo.max
fp8_min = -fp8_max
_silu_and_mul_post_quant_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
output_scale,
*output_scale.stride(),
masked_m,
size_n,
fp8_max,
fp8_min,
BLOCK_N=BLOCK_N,
NUM_STAGE=NUM_STAGES,
num_warps=num_warps,
)
return
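For reference, a minimal shape sketch of how `silu_and_mul_masked_post_quant_fwd` is expected to be called; all sizes below are hypothetical and chosen only for illustration:

```python
import torch

from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd

# Hypothetical sizes: 8 local experts, 128 padded tokens per expert, gate+up width 4096.
expert_num, token_num_padded, hidden_dim, group_size = 8, 128, 4096, 128

gate_up = torch.randn(
    expert_num, token_num_padded, hidden_dim, device="cuda", dtype=torch.bfloat16
)
out = torch.empty(
    expert_num,
    token_num_padded,
    hidden_dim // 2,
    device="cuda",
    dtype=torch.float8_e4m3fn,
)
out_scale = torch.empty(
    expert_num,
    token_num_padded,
    hidden_dim // 2 // group_size,
    device="cuda",
    dtype=torch.float32,
)
# Number of valid (non-padded) tokens per expert.
masked_m = torch.randint(
    0, token_num_padded + 1, (expert_num,), device="cuda", dtype=torch.int64
)

silu_and_mul_masked_post_quant_fwd(gate_up, out, out_scale, group_size, masked_m)
```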
@triton.jit
def tanh(x):
return 2 * tl.sigmoid(2 * x) - 1
......
......@@ -3,12 +3,16 @@ from typing import Callable, List, Optional, Tuple
import torch
# TODO: use deep_gemm masked kernel after low latency dispatch
# import deep_gemm
# from deep_gemm import (
# get_col_major_tma_aligned_tensor,
# m_grouped_gemm_fp8_fp8_bf16_nt_masked,
# )
try:
from deep_gemm import (
get_col_major_tma_aligned_tensor,
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
)
use_deep_gemm = True
except ImportError:
use_deep_gemm = False
from torch.nn import Module
from sglang.srt.custom_op import CustomOp
......@@ -22,6 +26,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
post_reorder_triton_kernel,
pre_reorder_triton_kernel,
run_moe_ep_preproess,
silu_and_mul_masked_post_quant_fwd,
silu_and_mul_triton_kernel,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
......@@ -809,6 +814,7 @@ class DeepEPMoE(EPMoE):
correction_bias: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
activation: str = "silu",
deepep_mode: str = "auto",
):
super().__init__(
num_experts,
......@@ -827,21 +833,41 @@ class DeepEPMoE(EPMoE):
custom_routing_function,
activation,
)
self.deepep_mode = deepep_mode
if self.deepep_mode in ["low_latency", "auto"]:
assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
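# Pre-pack (weight, scale) pairs in the layout consumed by deep_gemm's masked grouped GEMM.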
self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)
def forward(
self,
hidden_states: torch.Tensor,
reorder_topk_ids: torch.Tensor,
seg_indptr: torch.Tensor,
masked_m: torch.Tensor,
expected_m: int,
forward_mode: ForwardMode,
):
# Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
if True: # not forward_mode.is_decode():
if self.deepep_mode == "normal" or (
self.deepep_mode == "auto" and not forward_mode.is_decode()
):
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
elif self.deepep_mode == "low_latency" or (
self.deepep_mode == "auto" and forward_mode.is_decode()
):
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
else:
return self.forward_deepgemm_masked(
hidden_states, reorder_topk_ids, seg_indptr
)
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
def forward_normal(
self,
......@@ -958,89 +984,66 @@ class DeepEPMoE(EPMoE):
def forward_deepgemm_masked(
self,
hidden_states: torch.Tensor,
reorder_topk_ids: torch.Tensor,
seg_indptr: torch.Tensor,
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
masked_m: torch.Tensor,
expected_m: int,
):
assert self.quant_method is not None
assert self.activation == "silu"
if self.activation_scheme == "dynamic" and not self.use_block_quant:
max_value = (
torch.max(hidden_states)
.repeat(self.num_experts_per_partition)
.to(torch.float32)
)
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
assert (
hidden_states_fp8[0].size(0) % 4 == 0
), f"TMA alignment error: {hidden_states_fp8[0].size(0)}"
# GroupGemm-0
num_groups, m, k = hidden_states_fp8[0].size()
n = self.w13_weight.size(1)
expected_m = min(expected_m, m)
gateup_output = torch.empty(
hidden_states.shape[0],
self.w13_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
(num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
)
if hidden_states.shape[0] > 0:
# Transpose earlier so that the testing will not trigger transposing kernels
hidden_states = (
hidden_states[0],
get_col_major_tma_aligned_tensor(hidden_states[1]),
)
"""
gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
hidden_states, self.w13_weight, out, masked_m, expected_m
)
"""
# Act
down_input = torch.empty(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=(
self.fp8_dtype
if (self.use_fp8_w8a8 and not self.use_block_quant)
else hidden_states.dtype
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2,
),
device=gateup_output.device,
dtype=self.fp8_dtype,
)
if self.w2_input_scale is None and not self.use_block_quant:
self.w2_input_scale = torch.ones(
self.num_experts_per_partition,
dtype=torch.float32,
device=hidden_states.device,
)
if self.activation == "silu":
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
gateup_output,
down_input,
scale_block_size = 128
down_input_scale = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
reorder_topk_ids,
self.w2_input_scale,
0,
self.num_experts_per_partition - 1,
BLOCK_SIZE=512,
)
else:
raise ValueError(f"Unsupported activation: {self.activation=}")
gateup_output.shape[2] // 2 // scale_block_size,
),
device=gateup_output.device,
dtype=torch.float32,
)
silu_and_mul_masked_post_quant_fwd(
gateup_output,
down_input,
down_input_scale,
scale_block_size,
masked_m,
)
# GroupGemm-1
n = self.w2_weight.size(1)
down_input_fp8 = (
down_input,
get_col_major_tma_aligned_tensor(down_input_scale),
)
down_output = torch.empty(
down_input.shape[0],
self.w2_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
(num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
)
if down_input.shape[0] > 0:
# Transpose earlier so that the testing will not trigger transposing kernels
down_input = (
down_input[0],
get_col_major_tma_aligned_tensor(down_input[1]),
)
"""
down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
down_input, self.w2_weight, out, masked_m, expected_m
)
"""
return down_output
......@@ -76,8 +76,7 @@ def get_buffer_low_latency(
assert num_experts % group.size() == 0
_buffer_low_latency = Buffer(
group,
0,
num_rdma_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_experts // group.size(),
)
......@@ -95,62 +94,63 @@ class DeepEPDispatcher:
group: torch.distributed.ProcessGroup,
router_topk: int,
permute_fusion: bool = False,
capacity_factor: float = None,
num_experts: int = None,
num_local_experts: int = None,
hidden_size: int = None,
params_dtype: torch.dtype = None,
deepep_mode: str = "auto",
async_finish: bool = False,
return_recv_hook: bool = False,
):
if not use_deepep:
raise ImportError(
"DeepEP is not installed. Please install DeepEP package from "
"https://github.com/deepseek-ai/deepep."
)
self.group = group
self.router_topk = router_topk
self.capacity_factor = capacity_factor
self.permute_fusion = permute_fusion
self.num_experts = num_experts
self.num_local_experts = num_local_experts
self.hidden_size = hidden_size
self.recv_expert_count = None
self.params_dtype = params_dtype
self.params_bytes = 2
# Metadata
self.token_indices = None
self.token_probs = None
# Handle used for combine operation
self.handle = None
self.async_finish = async_finish
# `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
# https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
self.num_max_dispatch_tokens_per_rank = 128
self.deepep_mode = deepep_mode
self.handle = None
if not use_deepep:
raise ImportError(
"DeepEP is not installed. Please install DeepEP package from "
"https://github.com/deepseek-ai/deepep."
if self.deepep_mode in ["normal", "auto"]: # for normal / auto mode
self.buffer_normal = get_buffer_normal(
self.group, self.hidden_size * self.params_bytes
)
self.buffer_normal = get_buffer_normal(
self.group, self.hidden_size * self.params_bytes
)
self.buffer_low_latency = None
# Todo: enable low latency dispatch
"""
self.buffer_low_latency = get_buffer_low_latency(
self.group,
self.num_max_dispatch_tokens_per_rank,
self.hidden_size * self.params_bytes,
self.num_experts,
)
"""
self.async_finish = async_finish
self.src2dst = None
if self.deepep_mode in ["low_latency", "auto"]: # for low_latency / auto mode
"""
num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
"""
# TODO(ch-wan): allow users to set this value
self.num_max_dispatch_tokens_per_rank = 128
self.buffer_low_latency = get_buffer_low_latency(
self.group,
self.num_max_dispatch_tokens_per_rank,
self.hidden_size,
self.num_experts,
)
self.return_recv_hook = return_recv_hook
def deepep_permute(
self,
hidden_states,
fp8_dtype=None,
use_fp8_w8a8=False,
use_block_quant=False,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
fp8_dtype: Optional[torch.dtype] = None,
use_fp8_w8a8: bool = False,
use_block_quant: bool = False,
):
reorder_topk_ids, src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
self.topk_idx, self.num_experts
reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
topk_idx, self.num_experts
)
num_total_tokens = reorder_topk_ids.numel()
gateup_input = torch.empty(
......@@ -166,14 +166,13 @@ class DeepEPDispatcher:
deepep_permute_triton_kernel[(hidden_states.shape[0],)](
hidden_states,
gateup_input,
src2dst,
self.topk_idx,
self.src2dst,
topk_idx,
None,
self.router_topk,
hidden_states.shape[1],
BLOCK_SIZE=512,
)
self.src2dst = src2dst
return reorder_topk_ids, seg_indptr, gateup_input
def dispatch(
......@@ -182,54 +181,64 @@ class DeepEPDispatcher:
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
num_experts: int,
forward_mode: ForwardMode,
num_max_dispatch_tokens_per_rank: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
forward_mode: ForwardMode = None,
) -> Tuple:
topk_idx = topk_idx.to(torch.int64)
# Todo: enable low latency dispatch
if True: # not forward_mode.is_decode():
reorder_topk_ids = torch.empty(
(0,), device=hidden_states.device, dtype=torch.int64
)
seg_indptr = torch.zeros(
(num_experts + 1,), device=hidden_states.device, dtype=torch.int64
)
masked_m = torch.empty(
(self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
)
expected_m = 0
if self.deepep_mode == "normal" or (
self.deepep_mode == "auto" and not forward_mode.is_decode()
):
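# Normal (prefill) path: all-to-all dispatch through the normal buffer, then
# permute tokens so each expert's rows are contiguous.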
(
hidden_states,
topk_idx,
topk_weights,
num_recv_tokens_per_expert_list,
handle,
event,
) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
self.tokens_per_expert = torch.tensor(
num_recv_tokens_per_expert_list,
device=hidden_states.device,
dtype=torch.int64,
)
else:
hidden_states, recv_expert_count, handle, event, hook = (
self.dispatch_low_latency(
hidden_states,
topk_idx,
num_max_dispatch_tokens_per_rank,
num_experts,
event.current_stream_wait() if self.async_finish else ()
if hidden_states.shape[0] > 0:
reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
)
elif self.deepep_mode == "low_latency" or (
self.deepep_mode == "auto" and forward_mode.is_decode()
):
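# Low-latency (decode) path: expected_m is a rough per-expert token-count estimate
# (ceiling of total routed token slots / num_experts) passed to the masked grouped GEMM.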
expected_m = (
hidden_states.shape[0]
* self.buffer_low_latency.group_size
* topk_idx.shape[1]
+ num_experts
) // num_experts
hidden_states, masked_m, event, hook = self.dispatch_low_latency(
hidden_states,
topk_idx,
num_max_dispatch_tokens_per_rank,
num_experts,
use_fp8=True,
)
self.recv_expert_count = recv_expert_count
if self.async_finish:
event.current_stream_wait()
self.handle = handle
self.topk_idx = topk_idx
self.topk_weights = topk_weights
if hidden_states.shape[0] > 0:
reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
hidden_states, fp8_dtype=hidden_states.dtype
)
hook() if self.return_recv_hook else event.current_stream_wait()
else:
reorder_topk_ids = torch.empty(
(0,), device=hidden_states.device, dtype=torch.int64
)
seg_indptr = torch.zeros(
(num_experts + 1,), device=hidden_states.device, dtype=torch.int64
)
return hidden_states, reorder_topk_ids, seg_indptr
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
return (
hidden_states,
topk_idx,
topk_weights,
reorder_topk_ids,
seg_indptr,
masked_m,
expected_m,
)
def dispatch_normal(
self,
......@@ -254,12 +263,15 @@ class DeepEPDispatcher:
allocate_on_comm_stream=previous_event is not None,
)
# FIXME: `handle` should be transmitted with tokens from dispatch to combine.
# However, doing so currently triggers an unexplained synchronization error, so
# `handle` is kept as a member variable instead.
(
recv_x,
recv_topk_idx,
recv_topk_weights,
num_recv_tokens_per_expert_list,
handle,
_, # num_recv_tokens_per_expert_list
self.handle,
event,
) = self.buffer_normal.dispatch(
x,
......@@ -278,8 +290,6 @@ class DeepEPDispatcher:
recv_x,
recv_topk_idx,
recv_topk_weights,
num_recv_tokens_per_expert_list,
handle,
event,
)
......@@ -289,18 +299,19 @@ class DeepEPDispatcher:
topk_idx: torch.Tensor,
num_max_dispatch_tokens_per_rank: int,
num_experts: int,
use_fp8: bool = False,
):
"""
# For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'
# Please please make sure to change DeepEP code in internode_ll.cu dispatch / combine first and then reinstall!
# For H20, there will be a CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'.
# Please make sure to change the DeepEP code in internode_ll.cu dispatch / combine as below first and then reinstall.
# For more details, refer to: https://github.com/deepseek-ai/DeepEP/issues/15#issuecomment-2709715782
+
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index f60e933..cddaabf 100644
index 76ae2e2..8ecd08f 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -307,14 +307,14 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int num_topk, int num_experts, int rank, int num_ranks,
@@ -310,8 +310,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
void* workspace, cudaStream_t stream, int phases) {
constexpr int kNumMaxTopK = 9;
- constexpr int kNumWarpsPerGroup = 10;
......@@ -308,16 +319,9 @@ class DeepEPDispatcher:
+ constexpr int kNumWarpsPerGroup = 8;
+ constexpr int kNumWarpGroups = 4;
EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
+
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
const auto num_sms = cell_div(num_experts, kNumWarpGroups);
EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
- EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
+ // EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
+
// Workspace checks
auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
@@ -505,8 +505,8 @@ void combine(void* combined_x,
@@ -501,8 +501,8 @@ void combine(void* combined_x,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream, int phases) {
......@@ -326,28 +330,33 @@ class DeepEPDispatcher:
+ constexpr int kNumWarpsPerGroup = 8;
+ constexpr int kNumWarpGroups = 4;
constexpr int kNumMaxTopk = 9;
+
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
"""
recv_hidden_states, recv_expert_count, handle, event, hook = (
packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
self.buffer_low_latency.low_latency_dispatch(
hidden_states,
topk_idx,
num_max_dispatch_tokens_per_rank,
num_experts,
async_finish=self.async_finish,
return_recv_hook=False, # True for double-batch overlapping, need call hook()
use_fp8=use_fp8,
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
)
)
# hook()
return recv_hidden_states, recv_expert_count, handle, event, hook
return packed_recv_hidden, packed_recv_count, event, hook
def combine(
self, hidden_states: torch.Tensor, forward_mode: ForwardMode
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Todo: enable low latency combine
if True: # not forward_mode.is_decode():
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
forward_mode: ForwardMode,
) -> torch.Tensor:
if self.deepep_mode == "normal" or (
self.deepep_mode == "auto" and not forward_mode.is_decode()
):
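# Normal (prefill) path: un-permute tokens back to their pre-dispatch order,
# then all-to-all combine through the normal buffer.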
if hidden_states.shape[0] > 0:
num_tokens = self.src2dst.shape[0] // self.router_topk
output = torch.empty(
......@@ -359,8 +368,8 @@ class DeepEPDispatcher:
hidden_states,
output,
self.src2dst,
self.topk_idx,
self.topk_weights,
topk_idx,
topk_weights,
self.router_topk,
hidden_states.shape[1],
BLOCK_SIZE=512,
......@@ -371,24 +380,30 @@ class DeepEPDispatcher:
device=hidden_states.device,
dtype=hidden_states.dtype,
)
hidden_states, event = self.combine_normal(output, self.handle)
else:
hidden_states, event = self.combine_normal(
output,
)
event.current_stream_wait() if self.async_finish else ()
elif self.deepep_mode == "low_latency" or (
self.deepep_mode == "auto" and forward_mode.is_decode()
):
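# Low-latency (decode) path: combine directly through the low-latency buffer
# using the handle saved at dispatch time.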
hidden_states, event, hook = self.combine_low_latency(
hidden_states, self.topk_idx, self.topk_weights, self.handle
hidden_states,
topk_idx,
topk_weights,
)
hook() if self.return_recv_hook else event.current_stream_wait()
else:
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
if self.async_finish:
event.current_stream_wait()
self.handle = None
return hidden_states
def combine_normal(self, x: torch.Tensor, handle: Tuple):
def combine_normal(self, x: torch.Tensor):
previous_event = Buffer.capture() if self.async_finish else None
combined_x, _, event = self.buffer_normal.combine(
x,
handle,
self.handle,
async_finish=self.async_finish,
previous_event=previous_event,
allocate_on_comm_stream=previous_event is not None,
......@@ -400,17 +415,15 @@ class DeepEPDispatcher:
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
handle: Tuple,
):
combined_hidden_states, event_overlap, hook = (
combined_hidden_states, event, hook = (
self.buffer_low_latency.low_latency_combine(
hidden_states,
topk_idx,
topk_weights,
handle,
async_finish=self.async_finish,
return_recv_hook=False, # True for double-batch overlapping, need call hook()
self.handle,
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
)
)
# hook()
return combined_hidden_states, event_overlap, hook
return combined_hidden_states, event, hook
......@@ -72,6 +72,7 @@ global_server_args_dict = {
"enable_dp_attention": ServerArgs.enable_dp_attention,
"enable_ep_moe": ServerArgs.enable_ep_moe,
"enable_deepep_moe": ServerArgs.enable_deepep_moe,
"deepep_mode": ServerArgs.deepep_mode,
"device": ServerArgs.device,
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
......
......@@ -147,6 +147,7 @@ class ModelRunner:
"enable_dp_attention": server_args.enable_dp_attention,
"enable_ep_moe": server_args.enable_ep_moe,
"enable_deepep_moe": server_args.enable_deepep_moe,
"deepep_mode": server_args.deepep_mode,
"device": server_args.device,
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
......@@ -272,7 +273,7 @@ class ModelRunner:
server_args.disable_radix_cache = True
if server_args.enable_deepep_moe:
logger.info("DeepEP is turned on.")
logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
def init_torch_distributed(self):
logger.info("Init torch distributed begin.")
......
......@@ -188,19 +188,35 @@ class DeepseekV2MoE(nn.Module):
if global_server_args_dict["enable_deepep_moe"]
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
)
self.experts = MoEImpl(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
prefix=add_prefix("experts", prefix),
)
if not global_server_args_dict["enable_deepep_moe"]:
self.experts = MoEImpl(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
prefix=add_prefix("experts", prefix),
)
else:
self.experts = MoEImpl(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
prefix=add_prefix("experts", prefix),
deepep_mode=global_server_args_dict["deepep_mode"],
)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
......@@ -227,6 +243,8 @@ class DeepseekV2MoE(nn.Module):
)
if global_server_args_dict["enable_deepep_moe"]:
# TODO: we will support tp < ep in the future
self.ep_size = get_tensor_model_parallel_world_size()
self.num_experts = config.n_routed_experts
self.top_k = config.num_experts_per_tok
self.renormalize = config.norm_topk_prob
......@@ -246,7 +264,9 @@ class DeepseekV2MoE(nn.Module):
num_local_experts=config.n_routed_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=global_server_args_dict["deepep_mode"],
async_finish=True, # TODO
return_recv_hook=True,
)
def forward(
......@@ -301,28 +321,39 @@ class DeepseekV2MoE(nn.Module):
num_expert_group=self.num_expert_group,
correction_bias=self.correction_bias,
)
if self.tp_size > 1:
recv_hidden_states, reorder_topk_ids, seg_indptr = (
self.deepep_dispatcher.dispatch(
hidden_states,
topk_idx,
topk_weights,
self.num_experts,
forward_mode,
)
if self.ep_size > 1:
(
hidden_states,
topk_idx,
topk_weights,
reorder_topk_ids,
seg_indptr,
masked_m,
expected_m,
) = self.deepep_dispatcher.dispatch(
hidden_states,
topk_idx,
topk_weights,
self.num_experts,
forward_mode=forward_mode,
)
final_hidden_states = (
self.experts(
hidden_states=recv_hidden_states,
hidden_states=hidden_states,
reorder_topk_ids=reorder_topk_ids,
seg_indptr=seg_indptr,
masked_m=masked_m,
expected_m=expected_m,
forward_mode=forward_mode,
)
* self.routed_scaling_factor
)
if self.tp_size > 1:
if self.ep_size > 1:
final_hidden_states = self.deepep_dispatcher.combine(
final_hidden_states, forward_mode
final_hidden_states,
topk_idx,
topk_weights,
forward_mode,
)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
......
......@@ -161,6 +161,7 @@ class ServerArgs:
enable_dp_attention: bool = False
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
deepep_mode: Optional[str] = "auto"
enable_torch_compile: bool = False
torch_compile_max_bs: int = 32
cuda_graph_max_bs: Optional[int] = None
......@@ -285,6 +286,13 @@ class ServerArgs:
if self.grammar_backend is None:
self.grammar_backend = "xgrammar"
# Expert parallelism
if self.enable_ep_moe:
self.ep_size = self.tp_size
logger.info(
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)
# Data parallelism attention
if self.enable_dp_attention:
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
......@@ -300,6 +308,10 @@ class ServerArgs:
self.enable_sp_layernorm = False
# DeepEP MoE
if self.enable_deepep_moe:
if self.deepep_mode == "auto":
assert (
not self.enable_dp_attention
), "DeepEP MoE `auto` mode is not supported with DP Attention."
self.ep_size = self.tp_size
self.enable_sp_layernorm = (
self.dp_size < self.tp_size if self.enable_dp_attention else True
......@@ -1082,6 +1094,12 @@ class ServerArgs:
action="store_true",
help="Enabling DeepEP MoE implementation for EP MoE.",
)
parser.add_argument(
"--deepep-mode",
type=str,
choices=["normal", "low_latency", "auto"],
help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
)
# Server warmups
parser.add_argument(
......