Unverified Commit ca75741e authored by fzyzcjy, committed by GitHub

Support async in DeepEP (#4610)


Co-authored-by: Cheng Wan <cwan39@gatech.edu>
parent c6d549e7
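
Note on the pattern this commit adopts: with `async_finish=True`, DeepEP launches dispatch/combine on its own communication stream and returns an event instead of blocking; the caller orders the communication after pending compute via `Buffer.capture()` and later makes the compute stream wait with `event.current_stream_wait()` before reading the received tensors. The sketch below reproduces that ordering with plain PyTorch streams and events; it is a conceptual analogue for illustration (requires a CUDA device), not the DeepEP API itself.

    import torch

    # Conceptual analogue of async_finish=True: communication runs on a side
    # stream, ordered against the compute stream by events rather than by
    # blocking the host.
    comm_stream = torch.cuda.Stream()
    x = torch.randn(1024, 4096, device="cuda")

    # Analogue of Buffer.capture(): record how far compute has progressed.
    previous_event = torch.cuda.Event()
    previous_event.record(torch.cuda.current_stream())

    with torch.cuda.stream(comm_stream):
        comm_stream.wait_event(previous_event)  # comm waits for pending compute
        y = x * 2                               # stand-in for the dispatch/combine transfer
        done = torch.cuda.Event()
        done.record(comm_stream)

    # Analogue of event.current_stream_wait(): the compute stream waits on the
    # comm event; the host is never synchronized.
    torch.cuda.current_stream().wait_event(done)
    z = y + 1  # safe: ordered after the "communication"
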
@@ -5,7 +5,6 @@ try:
 except ImportError:
     use_deepep = False
-import os
 from typing import Optional, Tuple
 import torch
@@ -101,6 +100,7 @@ class DeepEPDispatcher:
         num_local_experts: int = None,
         hidden_size: int = None,
         params_dtype: torch.dtype = None,
+        async_finish: bool = False,
     ):
         self.group = group
         self.router_topk = router_topk
@@ -117,6 +117,7 @@ class DeepEPDispatcher:
         self.token_probs = None
         # Handle used for combine operation
         self.handle = None
+        self.async_finish = async_finish
         # `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
         # https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
@@ -182,7 +183,6 @@ class DeepEPDispatcher:
         topk_weights: torch.Tensor,
         num_experts: int,
         forward_mode: ForwardMode,
-        previous_event=None,
         num_max_dispatch_tokens_per_rank: int = 128,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         topk_idx = topk_idx.to(torch.int64)
@@ -195,9 +195,7 @@ class DeepEPDispatcher:
                 num_recv_tokens_per_expert_list,
                 handle,
                 event,
-            ) = self.dispatch_normal(
-                hidden_states, topk_idx, topk_weights, num_experts, previous_event
-            )
+            ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
             self.tokens_per_expert = torch.tensor(
                 num_recv_tokens_per_expert_list,
                 device=hidden_states.device,
@@ -213,6 +211,10 @@ class DeepEPDispatcher:
                 )
             )
             self.recv_expert_count = recv_expert_count
+        if self.async_finish:
+            event.current_stream_wait()
+
         self.handle = handle
         self.topk_idx = topk_idx
         self.topk_weights = topk_weights
@@ -235,8 +237,9 @@ class DeepEPDispatcher:
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         num_experts: int,
-        previous_event=None,
     ):
+        previous_event = Buffer.capture() if self.async_finish else None
         (
             num_tokens_per_rank,
             num_tokens_per_rdma_rank,
@@ -247,8 +250,8 @@ class DeepEPDispatcher:
             topk_idx,
             num_experts,
             previous_event=previous_event,
-            async_finish=False,
-            allocate_on_comm_stream=False,
+            async_finish=self.async_finish,
+            allocate_on_comm_stream=previous_event is not None,
         )
         (
@@ -267,8 +270,8 @@ class DeepEPDispatcher:
             is_token_in_rank=is_token_in_rank,
             num_tokens_per_expert=num_tokens_per_expert,
             previous_event=previous_event,
-            async_finish=False,
-            allocate_on_comm_stream=False,
+            async_finish=self.async_finish,
+            allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
         )
         return (
@@ -333,7 +336,7 @@ class DeepEPDispatcher:
                 topk_idx,
                 num_max_dispatch_tokens_per_rank,
                 num_experts,
-                async_finish=False,
+                async_finish=self.async_finish,
                 return_recv_hook=False,  # True for double-batch overlapping, need call hook()
             )
         )
@@ -373,16 +376,22 @@ class DeepEPDispatcher:
             hidden_states, event, hook = self.combine_low_latency(
                 hidden_states, self.topk_idx, self.topk_weights, self.handle
             )
+        if self.async_finish:
+            event.current_stream_wait()
+
         self.handle = None
         return hidden_states

-    def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
+    def combine_normal(self, x: torch.Tensor, handle: Tuple):
+        previous_event = Buffer.capture() if self.async_finish else None
         combined_x, _, event = self.buffer_normal.combine(
             x,
             handle,
-            async_finish=False,
+            async_finish=self.async_finish,
             previous_event=previous_event,
-            allocate_on_comm_stream=False,
+            allocate_on_comm_stream=previous_event is not None,
         )
         return combined_x, event
@@ -399,7 +408,7 @@ class DeepEPDispatcher:
                 topk_idx,
                 topk_weights,
                 handle,
-                async_finish=False,
+                async_finish=self.async_finish,
                 return_recv_hook=False,  # True for double-batch overlapping, need call hook()
             )
         )
...
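
Taken together, the normal-path changes above follow one pattern: capture an event on the compute stream, pass it to DeepEP as `previous_event`, run the kernel with `async_finish=self.async_finish`, and allocate on the communication stream only when there is an event to order the allocation against (since `previous_event` is non-None only when `self.async_finish` is set, the two `allocate_on_comm_stream` guards above are equivalent). A condensed sketch of the combine side, using only calls visible in this diff; `buffer_normal` is the dispatcher's `deep_ep.Buffer` and `handle` was saved at dispatch time:

    from deep_ep import Buffer

    def combine_async(buffer_normal: Buffer, x, handle, async_finish: bool = True):
        # Chain the communication after whatever compute produced `x`.
        previous_event = Buffer.capture() if async_finish else None
        combined_x, _, event = buffer_normal.combine(
            x,
            handle,
            async_finish=async_finish,
            previous_event=previous_event,
            allocate_on_comm_stream=previous_event is not None,
        )
        if async_finish:
            # Block the compute stream (not the host) until combine finishes.
            event.current_stream_wait()
        return combined_x
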
@@ -239,6 +239,7 @@ class DeepseekV2MoE(nn.Module):
             num_local_experts=config.n_routed_experts // self.tp_size,
             hidden_size=config.hidden_size,
             params_dtype=config.torch_dtype,
+            async_finish=True,  # TODO
         )

     def forward(
...
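
The `# TODO` marks that `async_finish=True` is hard-coded at this call site, presumably pending a proper configuration knob. A hypothetical way to plumb it through, inside `DeepseekV2MoE.__init__` (the `enable_deepep_async` attribute is invented here for illustration and is not part of this commit):

    # Hypothetical plumbing, not part of this commit: drive the flag from a
    # config attribute instead of hard-coding it.
    dispatcher = DeepEPDispatcher(
        group=group,
        router_topk=config.num_experts_per_tok,  # assumed config field
        num_local_experts=config.n_routed_experts // self.tp_size,
        hidden_size=config.hidden_size,
        params_dtype=config.torch_dtype,
        async_finish=getattr(config, "enable_deepep_async", True),  # invented knob
    )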