Unverified Commit bc3f6db2 authored by Jinyan Chen, committed by GitHub

[Fix] DeepEP Compatibility with Low Latency (#5068)


Co-authored-by: ch-wan <cwan39@gatech.edu>
parent aac531c5
@@ -7,6 +7,7 @@ try:
except ImportError:
use_deepep = False
from enum import IntEnum, auto
from typing import Optional, Tuple
import torch
@@ -19,70 +20,95 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode
_buffer_normal = None
_buffer_low_latency = None
class DeepEPDispatchMode(IntEnum):
NORMAL = auto()
LOW_LATENCY = auto()
def _get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
"""
Copy from DeepEP example usage in model inference prefilling.
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
"""
global _buffer_normal
class DeepEPBuffer:
num_nvl_bytes, num_rdma_bytes = 0, 0
for config in (
Buffer.get_dispatch_config(group.size()),
Buffer.get_combine_config(group.size()),
):
num_nvl_bytes = max(
config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes
)
num_rdma_bytes = max(
config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes
)
_buffer: Optional[Buffer] = None
_dispatch_mode: Optional[DeepEPDispatchMode] = None
_hidden_size: Optional[int] = None
_num_max_dispatch_tokens_per_rank: Optional[int] = None
_num_experts: Optional[int] = None
if (
_buffer_normal is None
or _buffer_normal.group != group
or _buffer_normal.num_nvl_bytes < num_nvl_bytes
or _buffer_normal.num_rdma_bytes < num_rdma_bytes
):
_buffer_normal = Buffer(group, num_nvl_bytes, num_rdma_bytes)
return _buffer_normal
def _get_buffer_low_latency(
group: dist.ProcessGroup,
num_max_dispatch_tokens_per_rank: int,
hidden: int,
num_experts: int,
):
"""
Copy from DeepEP example usage in model inference decoding.
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
"""
global _buffer_low_latency
num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts
)
if (
_buffer_low_latency is None
or _buffer_low_latency.group != group
or not _buffer_low_latency.low_latency_mode
or _buffer_low_latency.num_rdma_bytes < num_rdma_bytes
@classmethod
def get_deepep_buffer(
cls,
group: dist.ProcessGroup,
hidden_size: int,
param_bytes: int,
deepep_mode: DeepEPMode,
num_max_dispatch_tokens_per_rank: Optional[int] = None,
num_experts: Optional[int] = None,
):
assert num_experts % group.size() == 0
_buffer_low_latency = Buffer(
if cls._buffer is not None:
return cls._buffer
cls._hidden_size = hidden_size
cls._num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
cls._num_experts = num_experts
num_nvl_bytes, num_rdma_bytes = 0, 0
if deepep_mode.enable_normal():
hidden_bytes = hidden_size * param_bytes
for config in (
Buffer.get_dispatch_config(group.size()),
Buffer.get_combine_config(group.size()),
):
num_nvl_bytes = max(
config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
num_nvl_bytes,
)
num_rdma_bytes = max(
config.get_rdma_buffer_size_hint(hidden_bytes, group.size()),
num_rdma_bytes,
)
if deepep_mode.enable_low_latency():
assert num_max_dispatch_tokens_per_rank is not None
assert num_experts is not None and num_experts % group.size() == 0
num_rdma_bytes = max(
Buffer.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank,
hidden_size,
group.size(),
num_experts,
),
num_rdma_bytes,
)
cls._buffer = Buffer(
group,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_experts // group.size(),
num_nvl_bytes,
num_rdma_bytes,
low_latency_mode=deepep_mode.enable_low_latency(),
num_qps_per_rank=(
num_experts // group.size() if deepep_mode.enable_low_latency() else 1
),
)
return _buffer_low_latency
return cls._buffer
@classmethod
def clean_buffer(cls):
if not cls._buffer.low_latency_mode:
return
cls._buffer.clean_low_latency_buffer(
cls._num_max_dispatch_tokens_per_rank,
cls._hidden_size,
cls._num_experts,
)
@classmethod
def set_dispatch_mode_as_normal(cls):
cls._dispatch_mode = DeepEPDispatchMode.NORMAL
@classmethod
def set_dispatch_mode_as_low_latency(cls):
if cls._dispatch_mode == DeepEPDispatchMode.NORMAL:
cls.clean_buffer()
cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
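
The `DeepEPBuffer` class above replaces the removed module-level helpers `_get_buffer_normal` and `_get_buffer_low_latency` with a single, lazily allocated `Buffer` that is sized for every mode enabled by `deepep_mode` and then cached. The snippet below is a minimal sketch of the call pattern the dispatcher implementations further down are expected to follow; it is illustrative only, and `group`, `hidden_size`, `num_experts`, and `deepep_mode` are placeholders rather than names defined in this diff.

# Minimal sketch, assuming a process group and a DeepEPMode value are already available.
buffer = DeepEPBuffer.get_deepep_buffer(
    group,                                 # dist.ProcessGroup spanning the EP ranks
    hidden_size,                           # model hidden size
    2,                                     # param_bytes, matching self.params_bytes = 2 below
    deepep_mode,                           # controls which NVL/RDMA size hints are requested
    num_max_dispatch_tokens_per_rank=128,  # matches the dispatcher default below
    num_experts=num_experts,
)

# Prefill path: tag the buffer for normal dispatch.
DeepEPBuffer.set_dispatch_mode_as_normal()

# Decode path: switching from NORMAL to LOW_LATENCY triggers clean_buffer(), which
# calls Buffer.clean_low_latency_buffer(...) before low-latency dispatch reuses the buffer.
DeepEPBuffer.set_dispatch_mode_as_low_latency()

Because `get_deepep_buffer` returns the cached `cls._buffer` on every subsequent call, the first call must already request the NVL/RDMA sizes for every mode that will be used later; when `deepep_mode` enables both paths, both size hints are folded into the single allocation.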
class _DeepEPDispatcherImplBase:
@@ -95,6 +121,7 @@ class _DeepEPDispatcherImplBase:
num_local_experts: int,
hidden_size: int,
params_dtype: torch.dtype,
deepep_mode: DeepEPMode,
):
if not use_deepep:
raise ImportError(
@@ -109,7 +136,10 @@ class _DeepEPDispatcherImplBase:
self.num_local_experts = num_local_experts
self.hidden_size = hidden_size
self.params_dtype = params_dtype
self.deepep_mode = deepep_mode
self.params_bytes = 2
self.num_max_dispatch_tokens_per_rank = 128
self.handle = None
@@ -118,8 +148,6 @@ class _DeepEPDispatcherImplBase:
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
num_experts: int,
num_max_dispatch_tokens_per_rank: int,
):
raise NotImplementedError
@@ -137,14 +165,14 @@ class _DeepEPDispatcherImplBase:
def combine_b(self, *args, **kwargs):
raise NotImplementedError
def _get_buffer(self) -> Buffer:
raise NotImplementedError
class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
def __init__(self, async_finish: bool, **kwargs):
super().__init__(**kwargs)
self.buffer_normal = _get_buffer_normal(
self.group, self.hidden_size * self.params_bytes
)
self.async_finish = async_finish
self.src2dst = None
@@ -153,24 +181,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
num_experts: int,
num_max_dispatch_tokens_per_rank: int,
):
topk_idx = topk_idx.to(torch.int64)
previous_event = Buffer.capture() if self.async_finish else None
return hidden_states, topk_idx, topk_weights, num_experts, previous_event
return hidden_states, topk_idx, topk_weights, previous_event
def dispatch_b(
self, hidden_states, topk_idx, topk_weights, num_experts, previous_event
):
def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
(
hidden_states,
topk_idx,
topk_weights,
event,
) = self._dispatch_core(
hidden_states, topk_idx, topk_weights, num_experts, previous_event
)
) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
event.current_stream_wait() if self.async_finish else ()
if hidden_states.shape[0] > 0:
reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
@@ -181,7 +203,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
(0,), device=hidden_states.device, dtype=torch.int64
)
seg_indptr = torch.zeros(
(num_experts + 1,), device=hidden_states.device, dtype=torch.int64
(self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
)
masked_m = expected_m = None
@@ -201,18 +223,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
x: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
num_experts: int,
previous_event,
):
buffer = self._get_buffer()
(
num_tokens_per_rank,
num_tokens_per_rdma_rank,
num_tokens_per_expert,
is_token_in_rank,
previous_event,
) = self.buffer_normal.get_dispatch_layout(
) = buffer.get_dispatch_layout(
topk_idx,
num_experts,
self.num_experts,
previous_event=previous_event,
async_finish=self.async_finish,
allocate_on_comm_stream=previous_event is not None,
@@ -221,6 +243,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
# FIXME: `handle` should be transmitted with the tokens from dispatch to combine.
# However, doing so currently triggers an unexplained synchronization error,
# so keeping `handle` as a member variable works for now.
(
recv_x,
recv_topk_idx,
@@ -228,7 +251,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
_, # num_recv_tokens_per_expert_list
self.handle,
event,
) = self.buffer_normal.dispatch(
) = buffer.dispatch(
x,
topk_idx=topk_idx,
topk_weights=topk_weights,
@@ -327,7 +350,8 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
return hidden_states
def _combine_core(self, x: torch.Tensor, previous_event):
combined_x, _, event = self.buffer_normal.combine(
buffer = self._get_buffer()
combined_x, _, event = buffer.combine(
x,
self.handle,
async_finish=self.async_finish,
@@ -336,6 +360,17 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
)
return combined_x, event
def _get_buffer(self):
DeepEPBuffer.set_dispatch_mode_as_normal()
return DeepEPBuffer.get_deepep_buffer(
self.group,
self.hidden_size,
self.params_bytes,
self.deepep_mode,
self.num_max_dispatch_tokens_per_rank,
self.num_experts,
)
class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
def __init__(self, return_recv_hook: bool, **kwargs):
@@ -345,14 +380,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
"""
# TODO(ch-wan): allow users to set this value
self.num_max_dispatch_tokens_per_rank = 128
self.buffer_low_latency = _get_buffer_low_latency(
self.group,
self.num_max_dispatch_tokens_per_rank,
self.hidden_size,
self.num_experts,
)
self.return_recv_hook = return_recv_hook
def dispatch_a(
@@ -360,21 +387,16 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
num_experts: int,
num_max_dispatch_tokens_per_rank: int,
):
buffer = self._get_buffer()
topk_idx = topk_idx.to(torch.int64)
expected_m = (
hidden_states.shape[0]
* self.buffer_low_latency.group_size
* topk_idx.shape[1]
+ num_experts
) // num_experts
hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
+ self.num_experts
) // self.num_experts
hidden_states, masked_m, event, hook = self._dispatch_core(
hidden_states,
topk_idx,
num_max_dispatch_tokens_per_rank,
num_experts,
use_fp8=True,
)
return (
@@ -415,8 +437,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
num_max_dispatch_tokens_per_rank: int,
num_experts: int,
use_fp8: bool = False,
):
"""
@@ -451,13 +471,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
"""
buffer = self._get_buffer()
packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
self.buffer_low_latency.low_latency_dispatch(
buffer.low_latency_dispatch(
hidden_states,
topk_idx,
num_max_dispatch_tokens_per_rank,
num_experts,
self.num_max_dispatch_tokens_per_rank,
self.num_experts,
use_fp8=use_fp8,
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
@@ -488,19 +508,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
):
combined_hidden_states, event, hook = (
self.buffer_low_latency.low_latency_combine(
hidden_states,
topk_idx,
topk_weights,
self.handle,
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
)
buffer = self._get_buffer()
combined_hidden_states, event, hook = buffer.low_latency_combine(
hidden_states,
topk_idx,
topk_weights,
self.handle,
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
)
self.handle = None
return combined_hidden_states, event, hook
def _get_buffer(self):
DeepEPBuffer.set_dispatch_mode_as_low_latency()
return DeepEPBuffer.get_deepep_buffer(
self.group,
self.hidden_size,
self.params_bytes,
self.deepep_mode,
self.num_max_dispatch_tokens_per_rank,
self.num_experts,
)
class DeepEPDispatcher:
def __init__(
@@ -526,18 +556,19 @@ class DeepEPDispatcher:
num_local_experts=num_local_experts,
hidden_size=hidden_size,
params_dtype=params_dtype,
deepep_mode=deepep_mode,
)
if self.deepep_mode.enable_normal():
self._normal_dispatcher = _DeepEPDispatcherImplNormal(
async_finish=async_finish,
**common_kwargs,
)
if self.deepep_mode.enable_low_latency():
self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency(
return_recv_hook=return_recv_hook,
**common_kwargs,
)
if self.deepep_mode.enable_normal():
self._normal_dispatcher = _DeepEPDispatcherImplNormal(
async_finish=async_finish,
**common_kwargs,
)
def dispatch(self, *args, **kwargs) -> Tuple:
self.dispatch_a(*args, **kwargs)
@@ -548,16 +579,12 @@ class DeepEPDispatcher:
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
num_experts: int,
num_max_dispatch_tokens_per_rank: int = 128,
forward_mode: ForwardMode = None,
):
inner_state = self._get_impl(forward_mode).dispatch_a(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
num_experts=num_experts,
num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
)
self._dispatch_intermediate_state = forward_mode, inner_state
@@ -589,7 +616,7 @@ class DeepEPDispatcher:
del self._combine_intermediate_state
return self._get_impl(forward_mode).combine_b(*inner_state)
def _get_impl(self, forward_mode: ForwardMode) -> "_DeepEPDispatcherImplBase":
def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
if resolved_deepep_mode == DeepEPMode.normal:
return self._normal_dispatcher
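
With `num_experts` and `num_max_dispatch_tokens_per_rank` removed from the dispatch arguments, a MoE layer now passes only the routed tensors plus the batch's `ForwardMode`, and `_get_impl` selects the normal or low-latency implementation per batch. Below is a minimal sketch of that flow under placeholder names (`dispatcher`, `run_experts`, and the tensors are not defined in this commit); the `dispatch` call mirrors the DeepseekV2MoE hunk further down, while the `combine` argument list is assumed from `combine_a` and may differ from the full signature.

# Illustrative two-phase flow; every name except dispatch/combine/forward_mode is a placeholder.
dispatch_out = dispatcher.dispatch(
    hidden_states,
    topk_idx,
    topk_weights,
    forward_mode=forward_mode,  # resolved internally via deepep_mode.resolve(forward_mode)
)
expert_out = run_experts(*dispatch_out)  # grouped expert computation on the dispatched tokens
hidden_states = dispatcher.combine(      # argument list assumed from combine_a()
    expert_out,
    topk_idx,
    topk_weights,
    forward_mode=forward_mode,
)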
@@ -72,7 +72,7 @@ class ForwardMode(IntEnum):
DUMMY_FIRST = auto()
def is_prefill(self):
return self == ForwardMode.PREFILL
return self.is_extend()
def is_extend(self):
return (
@@ -324,6 +324,7 @@ class DeepseekV2MoE(nn.Module):
correction_bias=self.correction_bias,
)
if self.ep_size > 1:
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
(
hidden_states,
topk_idx,
@@ -336,7 +337,6 @@ class DeepseekV2MoE(nn.Module):
hidden_states,
topk_idx,
topk_weights,
self.num_experts,
forward_mode=forward_mode,
)
final_hidden_states = (
@@ -1101,6 +1101,7 @@ class ServerArgs:
"--deepep-mode",
type=str,
choices=["normal", "low_latency", "auto"],
default="auto",
help="Select the mode when enabling DeepEP MoE: `normal`, `low_latency`, or `auto`. Default is `auto`, which uses `low_latency` for decode batches and `normal` for prefill batches.",
)
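
For reference, the `auto` default added above resolves per batch: low-latency dispatch for decode batches and normal dispatch for prefill batches, matching `_get_impl` in the dispatcher. The snippet below is a hedged sketch of that mapping, not the actual `DeepEPMode.resolve` implementation, which lives elsewhere in the codebase; the `auto` and `low_latency` member names are taken from the CLI choices, and `ForwardMode.is_decode()` is assumed to exist with the obvious meaning.

# Sketch only: the mode resolution implied by the --deepep-mode help text.
def resolve_deepep_mode(deepep_mode: DeepEPMode, forward_mode: ForwardMode) -> DeepEPMode:
    if deepep_mode != DeepEPMode.auto:
        return deepep_mode  # `normal` or `low_latency` applies to every batch
    return DeepEPMode.low_latency if forward_mode.is_decode() else DeepEPMode.normal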