Unverified commit febe21ce, authored by fzyzcjy and committed by GitHub

Small refactor DeepEPDispatcher into subclasses (#4994)

parent a995a773
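For orientation, here is a minimal, self-contained sketch of the shape this refactor gives the dispatcher: a DeepEPDispatcher facade that resolves the mode per forward pass and delegates to per-mode implementation subclasses. This is illustrative only; the DeepEPMode and ForwardMode stand-ins below are simplified assumptions, not the real sglang types, and the real method bodies are in the diff that follows.

    # Illustrative sketch (not part of the commit). Stand-in types are assumptions.
    from enum import Enum, auto


    class DeepEPMode(Enum):
        normal = auto()
        low_latency = auto()
        auto = auto()

        def resolve(self, forward_mode):
            if self != DeepEPMode.auto:
                return self
            # assumption: decode-like forward modes use the low-latency path
            return DeepEPMode.low_latency if forward_mode == "decode" else DeepEPMode.normal


    class _DeepEPDispatcherImplBase:
        def dispatch(self, *args, **kwargs):
            raise NotImplementedError

        def combine(self, *args, **kwargs):
            raise NotImplementedError


    class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
        def dispatch(self, *args, **kwargs):
            return "normal dispatch"  # real code: buffer-normal dispatch + Triton permute

        def combine(self, *args, **kwargs):
            return "normal combine"


    class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
        def dispatch(self, *args, **kwargs):
            return "low-latency dispatch"  # real code: low-latency buffer dispatch

        def combine(self, *args, **kwargs):
            return "low-latency combine"


    class DeepEPDispatcher:
        """Facade: picks one impl per forward mode instead of branching inside each method."""

        def __init__(self, deepep_mode=DeepEPMode.auto):
            self.deepep_mode = deepep_mode
            self._normal_dispatcher = _DeepEPDispatcherImplNormal()
            self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency()

        def dispatch(self, *args, forward_mode=None, **kwargs):
            return self._get_dispatcher(forward_mode).dispatch(*args, **kwargs)

        def combine(self, *args, forward_mode=None, **kwargs):
            return self._get_dispatcher(forward_mode).combine(*args, **kwargs)

        def _get_dispatcher(self, forward_mode):
            resolved = self.deepep_mode.resolve(forward_mode)
            if resolved == DeepEPMode.normal:
                return self._normal_dispatcher
            return self._low_latency_dispatcher


    if __name__ == "__main__":
        d = DeepEPDispatcher()
        print(d.dispatch(forward_mode="decode"))  # -> "low-latency dispatch"
        print(d.combine(forward_mode="extend"))   # -> "normal combine"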
@@ -23,7 +23,7 @@ _buffer_normal = None
 _buffer_low_latency = None
 
-def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
+def _get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
     """
     Copy from DeepEP example usage in model inference prefilling.
     https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
@@ -53,7 +53,7 @@ def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
     return _buffer_normal
 
-def get_buffer_low_latency(
+def _get_buffer_low_latency(
     group: dist.ProcessGroup,
     num_max_dispatch_tokens_per_rank: int,
     hidden: int,
@@ -85,24 +85,16 @@ def get_buffer_low_latency(
     return _buffer_low_latency
 
-class DeepEPDispatcher:
-    """
-    Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
-    https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
-    """
+class _DeepEPDispatcherImplBase:
     def __init__(
         self,
         group: torch.distributed.ProcessGroup,
         router_topk: int,
-        permute_fusion: bool = False,
-        num_experts: int = None,
-        num_local_experts: int = None,
-        hidden_size: int = None,
-        params_dtype: torch.dtype = None,
-        deepep_mode: DeepEPMode = DeepEPMode.auto,
-        async_finish: bool = False,
-        return_recv_hook: bool = False,
+        permute_fusion: bool,
+        num_experts: int,
+        num_local_experts: int,
+        hidden_size: int,
+        params_dtype: torch.dtype,
     ):
         if not use_deepep:
             raise ImportError(
@@ -119,63 +111,36 @@ class DeepEPDispatcher:
         self.params_dtype = params_dtype
         self.params_bytes = 2
-        self.deepep_mode = deepep_mode
         self.handle = None
-        if self.deepep_mode.enable_normal():
-            self.buffer_normal = get_buffer_normal(
-                self.group, self.hidden_size * self.params_bytes
-            )
-            self.async_finish = async_finish
-            self.src2dst = None
-        if self.deepep_mode.enable_low_latency():
-            """
-            num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
-            https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
-            """
-            # TODO(ch-wan): allow users to set this value
-            self.num_max_dispatch_tokens_per_rank = 128
-            self.buffer_low_latency = get_buffer_low_latency(
-                self.group,
-                self.num_max_dispatch_tokens_per_rank,
-                self.hidden_size,
-                self.num_experts,
-            )
-            self.return_recv_hook = return_recv_hook
-
-    def deepep_permute(
-        self,
-        hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        fp8_dtype: Optional[torch.dtype] = None,
-        use_fp8_w8a8: bool = False,
-        use_block_quant: bool = False,
-    ):
-        reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
-            topk_idx, self.num_experts
-        )
-        num_total_tokens = reorder_topk_ids.numel()
-        gateup_input = torch.empty(
-            (int(num_total_tokens), hidden_states.shape[1]),
-            device=hidden_states.device,
-            dtype=(
-                fp8_dtype
-                if (use_fp8_w8a8 and not use_block_quant)
-                else hidden_states.dtype
-            ),
-        )
-        # PreReorder
-        deepep_permute_triton_kernel[(hidden_states.shape[0],)](
-            hidden_states,
-            gateup_input,
-            self.src2dst,
-            topk_idx,
-            None,
-            self.router_topk,
-            hidden_states.shape[1],
-            BLOCK_SIZE=512,
-        )
-        return reorder_topk_ids, seg_indptr, gateup_input
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: int,
+    ):
+        raise NotImplementedError
+
+    def combine(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
+    def __init__(self, async_finish: bool, **kwargs):
+        super().__init__(**kwargs)
+
+        self.buffer_normal = _get_buffer_normal(
+            self.group, self.hidden_size * self.params_bytes
+        )
+        self.async_finish = async_finish
+        self.src2dst = None
 
     def dispatch(
         self,
@@ -183,51 +148,34 @@ class DeepEPDispatcher:
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         num_experts: int,
-        num_max_dispatch_tokens_per_rank: int = 128,
-        forward_mode: ForwardMode = None,
-    ) -> Tuple:
+        num_max_dispatch_tokens_per_rank: int,
+    ):
         topk_idx = topk_idx.to(torch.int64)
-        reorder_topk_ids = torch.empty(
-            (0,), device=hidden_states.device, dtype=torch.int64
-        )
-        seg_indptr = torch.zeros(
-            (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
-        )
-        masked_m = torch.empty(
-            (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
-        )
-        expected_m = 0
-        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
-        if resolved_deepep_mode == DeepEPMode.normal:
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                event,
-            ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
-            event.current_stream_wait() if self.async_finish else ()
-            if hidden_states.shape[0] > 0:
-                reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
-                    hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
-                )
-        elif resolved_deepep_mode == DeepEPMode.low_latency:
-            expected_m = (
-                hidden_states.shape[0]
-                * self.buffer_low_latency.group_size
-                * topk_idx.shape[1]
-                + num_experts
-            ) // num_experts
-            hidden_states, masked_m, event, hook = self.dispatch_low_latency(
-                hidden_states,
-                topk_idx,
-                num_max_dispatch_tokens_per_rank,
-                num_experts,
-                use_fp8=True,
-            )
-            hook() if self.return_recv_hook else event.current_stream_wait()
-        else:
-            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
+        (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            event,
+        ) = self._dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
+        event.current_stream_wait() if self.async_finish else ()
+        if hidden_states.shape[0] > 0:
+            reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
+                hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
+            )
+        else:
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+            )
+
+        # TODO
+        # masked_m = torch.empty(
+        #     (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
+        # )
+        # expected_m = 0
+        masked_m = expected_m = None
 
         return (
             hidden_states,
@@ -239,7 +187,7 @@ class DeepEPDispatcher:
             expected_m,
         )
 
-    def dispatch_normal(
+    def _dispatch_normal(
         self,
         x: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -292,7 +240,156 @@ class DeepEPDispatcher:
             event,
         )
 
-    def dispatch_low_latency(
+    def _deepep_permute(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        fp8_dtype: Optional[torch.dtype] = None,
+        use_fp8_w8a8: bool = False,
+        use_block_quant: bool = False,
+    ):
+        """
+        Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
+        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
+        """
+        reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
+            topk_idx, self.num_experts
+        )
+        num_total_tokens = reorder_topk_ids.numel()
+        gateup_input = torch.empty(
+            (int(num_total_tokens), hidden_states.shape[1]),
+            device=hidden_states.device,
+            dtype=(
+                fp8_dtype
+                if (use_fp8_w8a8 and not use_block_quant)
+                else hidden_states.dtype
+            ),
+        )
+        # PreReorder
+        deepep_permute_triton_kernel[(hidden_states.shape[0],)](
+            hidden_states,
+            gateup_input,
+            self.src2dst,
+            topk_idx,
+            None,
+            self.router_topk,
+            hidden_states.shape[1],
+            BLOCK_SIZE=512,
+        )
+        return reorder_topk_ids, seg_indptr, gateup_input
+
+    def combine(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ) -> torch.Tensor:
+        if hidden_states.shape[0] > 0:
+            num_tokens = self.src2dst.shape[0] // self.router_topk
+            output = torch.empty(
+                (num_tokens, hidden_states.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+            deepep_post_reorder_triton_kernel[(num_tokens,)](
+                hidden_states,
+                output,
+                self.src2dst,
+                topk_idx,
+                topk_weights,
+                self.router_topk,
+                hidden_states.shape[1],
+                BLOCK_SIZE=512,
+            )
+        else:
+            output = torch.zeros(
+                (0, hidden_states.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+        hidden_states, event = self._combine_normal(
+            output,
+        )
+        event.current_stream_wait() if self.async_finish else ()
+        return hidden_states
+
+    def _combine_normal(self, x: torch.Tensor):
+        previous_event = Buffer.capture() if self.async_finish else None
+        combined_x, _, event = self.buffer_normal.combine(
+            x,
+            self.handle,
+            async_finish=self.async_finish,
+            previous_event=previous_event,
+            allocate_on_comm_stream=previous_event is not None,
+        )
+        return combined_x, event
+
+
+class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
+    def __init__(self, return_recv_hook: bool, **kwargs):
+        super().__init__(**kwargs)
+
+        """
+        num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
+        https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
+        """
+        # TODO(ch-wan): allow users to set this value
+        self.num_max_dispatch_tokens_per_rank = 128
+        self.buffer_low_latency = _get_buffer_low_latency(
+            self.group,
+            self.num_max_dispatch_tokens_per_rank,
+            self.hidden_size,
+            self.num_experts,
+        )
+        self.return_recv_hook = return_recv_hook
+
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: int,
+    ):
+        topk_idx = topk_idx.to(torch.int64)
+        expected_m = (
+            hidden_states.shape[0]
+            * self.buffer_low_latency.group_size
+            * topk_idx.shape[1]
+            + num_experts
+        ) // num_experts
+        hidden_states, masked_m, event, hook = self._dispatch_low_latency(
+            hidden_states,
+            topk_idx,
+            num_max_dispatch_tokens_per_rank,
+            num_experts,
+            use_fp8=True,
+        )
+        hook() if self.return_recv_hook else event.current_stream_wait()
+
+        # TODO
+        # reorder_topk_ids = torch.empty(
+        #     (0,), device=hidden_states.device, dtype=torch.int64
+        # )
+        # seg_indptr = torch.zeros(
+        #     (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+        # )
+        reorder_topk_ids = seg_indptr = None
+
+        return (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            reorder_topk_ids,
+            seg_indptr,
+            masked_m,
+            expected_m,
+        )
+
+    def _dispatch_low_latency(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -351,62 +448,17 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_mode: ForwardMode,
     ) -> torch.Tensor:
-        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
-        if resolved_deepep_mode == DeepEPMode.normal:
-            if hidden_states.shape[0] > 0:
-                num_tokens = self.src2dst.shape[0] // self.router_topk
-                output = torch.empty(
-                    (num_tokens, hidden_states.shape[1]),
-                    device=hidden_states.device,
-                    dtype=hidden_states.dtype,
-                )
-                deepep_post_reorder_triton_kernel[(num_tokens,)](
-                    hidden_states,
-                    output,
-                    self.src2dst,
-                    topk_idx,
-                    topk_weights,
-                    self.router_topk,
-                    hidden_states.shape[1],
-                    BLOCK_SIZE=512,
-                )
-            else:
-                output = torch.zeros(
-                    (0, hidden_states.shape[1]),
-                    device=hidden_states.device,
-                    dtype=hidden_states.dtype,
-                )
-            hidden_states, event = self.combine_normal(
-                output,
-            )
-            event.current_stream_wait() if self.async_finish else ()
-        elif resolved_deepep_mode == DeepEPMode.low_latency:
-            hidden_states, event, hook = self.combine_low_latency(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-            )
-            hook() if self.return_recv_hook else event.current_stream_wait()
-        else:
-            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
+        hidden_states, event, hook = self._combine_low_latency(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+        )
+        hook() if self.return_recv_hook else event.current_stream_wait()
 
         return hidden_states
 
-    def combine_normal(self, x: torch.Tensor):
-        previous_event = Buffer.capture() if self.async_finish else None
-        combined_x, _, event = self.buffer_normal.combine(
-            x,
-            self.handle,
-            async_finish=self.async_finish,
-            previous_event=previous_event,
-            allocate_on_comm_stream=previous_event is not None,
-        )
-        return combined_x, event
-
-    def combine_low_latency(
+    def _combine_low_latency(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -423,3 +475,80 @@ class DeepEPDispatcher:
             )
         )
         return combined_hidden_states, event, hook
+
+
+class DeepEPDispatcher:
+    def __init__(
+        self,
+        group: torch.distributed.ProcessGroup,
+        router_topk: int,
+        permute_fusion: bool = False,
+        num_experts: int = None,
+        num_local_experts: int = None,
+        hidden_size: int = None,
+        params_dtype: torch.dtype = None,
+        deepep_mode: DeepEPMode = DeepEPMode.auto,
+        async_finish: bool = False,
+        return_recv_hook: bool = False,
+    ):
+        self.deepep_mode = deepep_mode
+
+        common_kwargs = dict(
+            group=group,
+            router_topk=router_topk,
+            permute_fusion=permute_fusion,
+            num_experts=num_experts,
+            num_local_experts=num_local_experts,
+            hidden_size=hidden_size,
+            params_dtype=params_dtype,
+        )
+
+        if self.deepep_mode.enable_normal():
+            self._normal_dispatcher = _DeepEPDispatcherImplNormal(
+                async_finish=async_finish,
+                **common_kwargs,
+            )
+        if self.deepep_mode.enable_low_latency():
+            self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency(
+                return_recv_hook=return_recv_hook,
+                **common_kwargs,
+            )
+
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: int = 128,
+        forward_mode: ForwardMode = None,
+    ) -> Tuple:
+        return self._get_dispatcher(forward_mode).dispatch(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            num_experts=num_experts,
+            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
+        )
+
+    def combine(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_mode: ForwardMode,
+    ) -> torch.Tensor:
+        return self._get_dispatcher(forward_mode).combine(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+        )
+
+    def _get_dispatcher(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
+        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+        if resolved_deepep_mode == DeepEPMode.normal:
+            return self._normal_dispatcher
+        elif resolved_deepep_mode == DeepEPMode.low_latency:
+            return self._low_latency_dispatcher
+        else:
+            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
@@ -188,35 +188,24 @@ class DeepseekV2MoE(nn.Module):
             if global_server_args_dict["enable_deepep_moe"]
             else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
         )
-        if not global_server_args_dict["enable_deepep_moe"]:
-            self.experts = MoEImpl(
-                num_experts=config.n_routed_experts,
-                top_k=config.num_experts_per_tok,
-                hidden_size=config.hidden_size,
-                intermediate_size=config.moe_intermediate_size,
-                renormalize=config.norm_topk_prob,
-                quant_config=quant_config,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                topk_group=config.topk_group,
-                correction_bias=self.gate.e_score_correction_bias,
-                prefix=add_prefix("experts", prefix),
-            )
-        else:
-            self.experts = MoEImpl(
-                num_experts=config.n_routed_experts,
-                top_k=config.num_experts_per_tok,
-                hidden_size=config.hidden_size,
-                intermediate_size=config.moe_intermediate_size,
-                renormalize=config.norm_topk_prob,
-                quant_config=quant_config,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                topk_group=config.topk_group,
-                correction_bias=self.gate.e_score_correction_bias,
-                prefix=add_prefix("experts", prefix),
-                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
-            )
+        self.experts = MoEImpl(
+            num_experts=config.n_routed_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            prefix=add_prefix("experts", prefix),
+            **(
+                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
+                if global_server_args_dict["enable_deepep_moe"]
+                else {}
+            ),
+        )
 
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
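The DeepseekV2MoE change above collapses the duplicated MoEImpl constructor calls into a single call by splatting an optional keyword dict. A tiny, self-contained illustration of that idiom follows; make_moe and its parameters are hypothetical stand-ins for this example, not sglang APIs.

    from typing import Optional


    def make_moe(num_experts: int, deepep_mode: Optional[str] = None):
        # Hypothetical stand-in constructor; only illustrates the conditional-kwargs idiom.
        return {"num_experts": num_experts, "deepep_mode": deepep_mode}


    enable_deepep_moe = True
    experts = make_moe(
        num_experts=8,
        # The extra kwarg is only merged in when DeepEP is enabled, so implementations
        # without a `deepep_mode` parameter never receive an unexpected keyword.
        **(dict(deepep_mode="auto") if enable_deepep_moe else {}),
    )
    print(experts)  # {'num_experts': 8, 'deepep_mode': 'auto'}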