Unverified Commit bc3f6db2 authored by Jinyan Chen, committed by GitHub

[Fix] DeepEP Compatibility with Low Latency (#5068)


Co-authored-by: ch-wan <cwan39@gatech.edu>
parent aac531c5
@@ -7,6 +7,7 @@ try:
 except ImportError:
     use_deepep = False
 
+from enum import IntEnum, auto
 from typing import Optional, Tuple
 
 import torch
@@ -19,70 +20,95 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 )
 
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 
-_buffer_normal = None
-_buffer_low_latency = None
-
-
-def _get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
-    """
-    Copy from DeepEP example usage in model inference prefilling.
-    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
-    """
-    global _buffer_normal
-
-    num_nvl_bytes, num_rdma_bytes = 0, 0
-    for config in (
-        Buffer.get_dispatch_config(group.size()),
-        Buffer.get_combine_config(group.size()),
-    ):
-        num_nvl_bytes = max(
-            config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes
-        )
-        num_rdma_bytes = max(
-            config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes
-        )
-
-    if (
-        _buffer_normal is None
-        or _buffer_normal.group != group
-        or _buffer_normal.num_nvl_bytes < num_nvl_bytes
-        or _buffer_normal.num_rdma_bytes < num_rdma_bytes
-    ):
-        _buffer_normal = Buffer(group, num_nvl_bytes, num_rdma_bytes)
-    return _buffer_normal
-
-
-def _get_buffer_low_latency(
-    group: dist.ProcessGroup,
-    num_max_dispatch_tokens_per_rank: int,
-    hidden: int,
-    num_experts: int,
-):
-    """
-    Copy from DeepEP example usage in model inference decoding.
-    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
-    """
-    global _buffer_low_latency
-    num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
-        num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts
-    )
-
-    if (
-        _buffer_low_latency is None
-        or _buffer_low_latency.group != group
-        or not _buffer_low_latency.low_latency_mode
-        or _buffer_low_latency.num_rdma_bytes < num_rdma_bytes
-    ):
-        assert num_experts % group.size() == 0
-        _buffer_low_latency = Buffer(
-            group,
-            num_rdma_bytes=num_rdma_bytes,
-            low_latency_mode=True,
-            num_qps_per_rank=num_experts // group.size(),
-        )
-    return _buffer_low_latency
+
+class DeepEPDispatchMode(IntEnum):
+    NORMAL = auto()
+    LOW_LATENCY = auto()
+
+
+class DeepEPBuffer:
+    _buffer: Optional[Buffer] = None
+    _dispatch_mode: Optional[DeepEPDispatchMode] = None
+    _hidden_size: Optional[int] = None
+    _num_max_dispatch_tokens_per_rank: Optional[int] = None
+    _num_experts: Optional[int] = None
+
+    @classmethod
+    def get_deepep_buffer(
+        cls,
+        group: dist.ProcessGroup,
+        hidden_size: int,
+        param_bytes: int,
+        deepep_mode: DeepEPMode,
+        num_max_dispatch_tokens_per_rank: int = None,
+        num_experts: int = None,
+    ):
+        if cls._buffer is not None:
+            return cls._buffer
+
+        cls._hidden_size = hidden_size
+        cls._num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
+        cls._num_experts = num_experts
+
+        num_nvl_bytes, num_rdma_bytes = 0, 0
+        if deepep_mode.enable_normal():
+            hidden_bytes = hidden_size * param_bytes
+            for config in (
+                Buffer.get_dispatch_config(group.size()),
+                Buffer.get_combine_config(group.size()),
+            ):
+                num_nvl_bytes = max(
+                    config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
+                    num_nvl_bytes,
+                )
+                num_rdma_bytes = max(
+                    config.get_rdma_buffer_size_hint(hidden_bytes, group.size()),
+                    num_rdma_bytes,
+                )
+        if deepep_mode.enable_low_latency():
+            assert num_max_dispatch_tokens_per_rank is not None
+            assert num_experts is not None and num_experts % group.size() == 0
+            num_rdma_bytes = max(
+                Buffer.get_low_latency_rdma_size_hint(
+                    num_max_dispatch_tokens_per_rank,
+                    hidden_size,
+                    group.size(),
+                    num_experts,
+                ),
+                num_rdma_bytes,
+            )
+
+        cls._buffer = Buffer(
+            group,
+            num_nvl_bytes,
+            num_rdma_bytes,
+            low_latency_mode=deepep_mode.enable_low_latency(),
+            num_qps_per_rank=(
+                num_experts // group.size() if deepep_mode.enable_low_latency() else 1
+            ),
+        )
+        return cls._buffer
+
+    @classmethod
+    def clean_buffer(cls):
+        if not cls._buffer.low_latency_mode:
+            return
+        cls._buffer.clean_low_latency_buffer(
+            cls._num_max_dispatch_tokens_per_rank,
+            cls._hidden_size,
+            cls._num_experts,
+        )
+
+    @classmethod
+    def set_dispatch_mode_as_normal(cls):
+        cls._dispatch_mode = DeepEPDispatchMode.NORMAL
+
+    @classmethod
+    def set_dispatch_mode_as_low_latency(cls):
+        if cls._dispatch_mode == DeepEPDispatchMode.NORMAL:
+            cls.clean_buffer()
+        cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
 
 
 class _DeepEPDispatcherImplBase:
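The hunk above replaces the two module-level buffers with one process-wide `Buffer`, sized for whichever paths `deepep_mode` enables (NVL/RDMA hints from the normal configs, low-latency RDMA hint folded in via `max`). A minimal sketch of how the dispatchers obtain it; the sizes are illustrative, `group` stands for the EP process group, and `DeepEPMode.auto` is assumed to enable both modes:

```python
# Sketch only; values are illustrative, not part of the diff.
buffer = DeepEPBuffer.get_deepep_buffer(
    group,                                  # torch.distributed EP process group
    hidden_size=7168,                       # model hidden size
    param_bytes=2,                          # bytes per element (bf16)
    deepep_mode=DeepEPMode.auto,            # size the buffer for normal + low latency
    num_max_dispatch_tokens_per_rank=128,
    num_experts=256,
)
# Subsequent calls return the same cached instance instead of re-allocating.
assert DeepEPBuffer.get_deepep_buffer(group, 7168, 2, DeepEPMode.auto, 128, 256) is buffer
```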
@@ -95,6 +121,7 @@ class _DeepEPDispatcherImplBase:
         num_local_experts: int,
         hidden_size: int,
         params_dtype: torch.dtype,
+        deepep_mode: DeepEPMode,
     ):
         if not use_deepep:
             raise ImportError(
@@ -109,7 +136,10 @@ class _DeepEPDispatcherImplBase:
         self.num_local_experts = num_local_experts
         self.hidden_size = hidden_size
         self.params_dtype = params_dtype
+        self.deepep_mode = deepep_mode
 
         self.params_bytes = 2
+        self.num_max_dispatch_tokens_per_rank = 128
 
         self.handle = None
@@ -118,8 +148,6 @@ class _DeepEPDispatcherImplBase:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
         raise NotImplementedError
@@ -137,14 +165,14 @@ class _DeepEPDispatcherImplBase:
     def combine_b(self, *args, **kwargs):
         raise NotImplementedError
 
+    def _get_buffer(self) -> Buffer:
+        raise NotImplementedError
+
 
 class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
     def __init__(self, async_finish: bool, **kwargs):
         super().__init__(**kwargs)
 
-        self.buffer_normal = _get_buffer_normal(
-            self.group, self.hidden_size * self.params_bytes
-        )
         self.async_finish = async_finish
         self.src2dst = None
@@ -153,24 +181,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
         topk_idx = topk_idx.to(torch.int64)
         previous_event = Buffer.capture() if self.async_finish else None
-        return hidden_states, topk_idx, topk_weights, num_experts, previous_event
+        return hidden_states, topk_idx, topk_weights, previous_event
 
-    def dispatch_b(
-        self, hidden_states, topk_idx, topk_weights, num_experts, previous_event
-    ):
+    def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
         (
             hidden_states,
             topk_idx,
             topk_weights,
             event,
-        ) = self._dispatch_core(
-            hidden_states, topk_idx, topk_weights, num_experts, previous_event
-        )
+        ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
         event.current_stream_wait() if self.async_finish else ()
         if hidden_states.shape[0] > 0:
             reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
@@ -181,7 +203,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
                 (0,), device=hidden_states.device, dtype=torch.int64
             )
             seg_indptr = torch.zeros(
-                (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
             )
             masked_m = expected_m = None
@@ -201,18 +223,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         x: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
         previous_event,
     ):
+        buffer = self._get_buffer()
         (
             num_tokens_per_rank,
             num_tokens_per_rdma_rank,
             num_tokens_per_expert,
             is_token_in_rank,
             previous_event,
-        ) = self.buffer_normal.get_dispatch_layout(
+        ) = buffer.get_dispatch_layout(
             topk_idx,
-            num_experts,
+            self.num_experts,
             previous_event=previous_event,
             async_finish=self.async_finish,
             allocate_on_comm_stream=previous_event is not None,
@@ -221,6 +243,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         # FIXME: `handle` should be transmitted with tokens from dispatch to combine.
         # However, doing this would incur an unknown synchronization error, but keeping
         # `handle` as a member variable works.
+
         (
             recv_x,
             recv_topk_idx,
@@ -228,7 +251,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             _,  # num_recv_tokens_per_expert_list
             self.handle,
             event,
-        ) = self.buffer_normal.dispatch(
+        ) = buffer.dispatch(
             x,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
@@ -327,7 +350,8 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         return hidden_states
 
     def _combine_core(self, x: torch.Tensor, previous_event):
-        combined_x, _, event = self.buffer_normal.combine(
+        buffer = self._get_buffer()
+        combined_x, _, event = buffer.combine(
             x,
             self.handle,
             async_finish=self.async_finish,
@@ -336,6 +360,17 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         )
         return combined_x, event
 
+    def _get_buffer(self):
+        DeepEPBuffer.set_dispatch_mode_as_normal()
+        return DeepEPBuffer.get_deepep_buffer(
+            self.group,
+            self.hidden_size,
+            self.params_bytes,
+            self.deepep_mode,
+            self.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+        )
+
 
 class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def __init__(self, return_recv_hook: bool, **kwargs):
@@ -345,14 +380,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
         https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
         """
-        # TODO(ch-wan): allow users to set this value
-        self.num_max_dispatch_tokens_per_rank = 128
-        self.buffer_low_latency = _get_buffer_low_latency(
-            self.group,
-            self.num_max_dispatch_tokens_per_rank,
-            self.hidden_size,
-            self.num_experts,
-        )
         self.return_recv_hook = return_recv_hook
 
     def dispatch_a(
@@ -360,21 +387,16 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
+        buffer = self._get_buffer()
         topk_idx = topk_idx.to(torch.int64)
         expected_m = (
-            hidden_states.shape[0]
-            * self.buffer_low_latency.group_size
-            * topk_idx.shape[1]
-            + num_experts
-        ) // num_experts
+            hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
+            + self.num_experts
+        ) // self.num_experts
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
             topk_idx,
-            num_max_dispatch_tokens_per_rank,
-            num_experts,
             use_fp8=True,
         )
         return (
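`expected_m` estimates how many tokens each expert will receive after dispatch; it is now computed from `self.num_experts` and the shared buffer's `group_size` instead of per-call arguments. A quick sanity check with made-up numbers (not taken from the diff):

```python
# Rough per-expert token estimate with illustrative values.
num_local_tokens = 4   # hidden_states.shape[0] on this rank
group_size = 8         # buffer.group_size (EP world size)
topk = 8               # topk_idx.shape[1], experts selected per token
num_experts = 256      # total routed experts

expected_m = (num_local_tokens * group_size * topk + num_experts) // num_experts
print(expected_m)  # 2 -> each expert is expected to receive roughly 2 tokens
```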
@@ -415,8 +437,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
-        num_max_dispatch_tokens_per_rank: int,
-        num_experts: int,
         use_fp8: bool = False,
     ):
         """
@@ -451,13 +471,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
         """
+        buffer = self._get_buffer()
         packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
-            self.buffer_low_latency.low_latency_dispatch(
+            buffer.low_latency_dispatch(
                 hidden_states,
                 topk_idx,
-                num_max_dispatch_tokens_per_rank,
-                num_experts,
+                self.num_max_dispatch_tokens_per_rank,
+                self.num_experts,
                 use_fp8=use_fp8,
                 async_finish=not self.return_recv_hook,
                 return_recv_hook=self.return_recv_hook,
@@ -488,8 +508,8 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        combined_hidden_states, event, hook = (
-            self.buffer_low_latency.low_latency_combine(
+        buffer = self._get_buffer()
+        combined_hidden_states, event, hook = buffer.low_latency_combine(
             hidden_states,
             topk_idx,
             topk_weights,
@@ -497,10 +517,20 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             async_finish=not self.return_recv_hook,
             return_recv_hook=self.return_recv_hook,
         )
-        )
         self.handle = None
         return combined_hidden_states, event, hook
 
+    def _get_buffer(self):
+        DeepEPBuffer.set_dispatch_mode_as_low_latency()
+        return DeepEPBuffer.get_deepep_buffer(
+            self.group,
+            self.hidden_size,
+            self.params_bytes,
+            self.deepep_mode,
+            self.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+        )
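Both implementations now fetch the same process-wide buffer through `_get_buffer`, and switching from the normal path (prefill) to the low-latency path (decode) triggers `clean_buffer()`, which calls `clean_low_latency_buffer` on the shared `Buffer`. A minimal sketch of the per-batch sequence the dispatchers perform internally:

```python
# Sketch only; the dispatcher impls call these internally for each batch.
DeepEPBuffer.set_dispatch_mode_as_normal()       # prefill batch -> normal dispatch
DeepEPBuffer.set_dispatch_mode_as_low_latency()  # decode batch: the NORMAL -> LOW_LATENCY
                                                 # switch runs clean_buffer() first
DeepEPBuffer.set_dispatch_mode_as_low_latency()  # staying in low latency does not re-clean
```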
 
 class DeepEPDispatcher:
     def __init__(
@@ -526,18 +556,19 @@ class DeepEPDispatcher:
             num_local_experts=num_local_experts,
             hidden_size=hidden_size,
             params_dtype=params_dtype,
+            deepep_mode=deepep_mode,
         )
 
+        if self.deepep_mode.enable_normal():
+            self._normal_dispatcher = _DeepEPDispatcherImplNormal(
+                async_finish=async_finish,
+                **common_kwargs,
+            )
         if self.deepep_mode.enable_low_latency():
             self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency(
                 return_recv_hook=return_recv_hook,
                 **common_kwargs,
            )
-        if self.deepep_mode.enable_normal():
-            self._normal_dispatcher = _DeepEPDispatcherImplNormal(
-                async_finish=async_finish,
-                **common_kwargs,
-            )
 
     def dispatch(self, *args, **kwargs) -> Tuple:
         self.dispatch_a(*args, **kwargs)
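With `deepep_mode` forwarded into the impls, callers no longer pass `num_experts` or `num_max_dispatch_tokens_per_rank` at dispatch time. A hedged sketch of the resulting call site; the constructor arguments shown are an abbreviated, illustrative subset (the real signature takes additional parameters):

```python
# Sketch with illustrative arguments, not the full DeepEPDispatcher signature.
dispatcher = DeepEPDispatcher(
    group=ep_group,
    num_experts=256,
    num_local_experts=256 // ep_group.size(),
    hidden_size=7168,
    params_dtype=torch.bfloat16,
    deepep_mode=DeepEPMode.auto,
    async_finish=True,
    return_recv_hook=False,
)

# num_experts / num_max_dispatch_tokens_per_rank are no longer passed per call.
outputs = dispatcher.dispatch(
    hidden_states, topk_idx, topk_weights, forward_mode=forward_mode
)
```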
@@ -548,16 +579,12 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int = 128,
         forward_mode: ForwardMode = None,
     ):
         inner_state = self._get_impl(forward_mode).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
-            num_experts=num_experts,
-            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
         )
         self._dispatch_intermediate_state = forward_mode, inner_state
@@ -589,7 +616,7 @@ class DeepEPDispatcher:
         del self._combine_intermediate_state
         return self._get_impl(forward_mode).combine_b(*inner_state)
 
-    def _get_impl(self, forward_mode: ForwardMode) -> "_DeepEPDispatcherImplBase":
+    def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
             return self._normal_dispatcher
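`_get_impl` selects the implementation from the resolved mode, so `auto` can use different dispatchers within one run. Assuming `DeepEPMode.resolve` maps decode batches to `low_latency` and other batches to `normal`, as the `--deepep-mode auto` help text below describes, the selection looks roughly like:

```python
# Illustrative only, under the assumption stated above.
dispatcher._get_impl(ForwardMode.EXTEND)  # -> self._normal_dispatcher
dispatcher._get_impl(ForwardMode.DECODE)  # -> self._low_latency_dispatcher
```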
...
@@ -72,7 +72,7 @@ class ForwardMode(IntEnum):
     DUMMY_FIRST = auto()
 
     def is_prefill(self):
-        return self == ForwardMode.PREFILL
+        return self.is_extend()
 
     def is_extend(self):
         return (
...
@@ -324,6 +324,7 @@ class DeepseekV2MoE(nn.Module):
             correction_bias=self.correction_bias,
         )
         if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
             (
                 hidden_states,
                 topk_idx,
@@ -336,7 +337,6 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states,
                 topk_idx,
                 topk_weights,
-                self.num_experts,
                 forward_mode=forward_mode,
             )
             final_hidden_states = (
...
@@ -1101,6 +1101,7 @@ class ServerArgs:
             "--deepep-mode",
             type=str,
             choices=["normal", "low_latency", "auto"],
+            default="auto",
             help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
         )