Unverified Commit ca75741e authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Support async in DeepEP (#4610)


Co-authored-by: default avatarCheng Wan <cwan39@gatech.edu>
parent c6d549e7
......@@ -5,7 +5,6 @@ try:
except ImportError:
use_deepep = False
import os
from typing import Optional, Tuple
import torch
......@@ -101,6 +100,7 @@ class DeepEPDispatcher:
num_local_experts: int = None,
hidden_size: int = None,
params_dtype: torch.dtype = None,
async_finish: bool = False,
):
self.group = group
self.router_topk = router_topk
......@@ -117,6 +117,7 @@ class DeepEPDispatcher:
self.token_probs = None
# Handle used for combine operation
self.handle = None
self.async_finish = async_finish
# `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
# https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
......@@ -182,7 +183,6 @@ class DeepEPDispatcher:
topk_weights: torch.Tensor,
num_experts: int,
forward_mode: ForwardMode,
previous_event=None,
num_max_dispatch_tokens_per_rank: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
topk_idx = topk_idx.to(torch.int64)
......@@ -195,9 +195,7 @@ class DeepEPDispatcher:
num_recv_tokens_per_expert_list,
handle,
event,
) = self.dispatch_normal(
hidden_states, topk_idx, topk_weights, num_experts, previous_event
)
) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
self.tokens_per_expert = torch.tensor(
num_recv_tokens_per_expert_list,
device=hidden_states.device,
......@@ -213,6 +211,10 @@ class DeepEPDispatcher:
)
)
self.recv_expert_count = recv_expert_count
if self.async_finish:
event.current_stream_wait()
self.handle = handle
self.topk_idx = topk_idx
self.topk_weights = topk_weights
......@@ -235,8 +237,9 @@ class DeepEPDispatcher:
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
num_experts: int,
previous_event=None,
):
previous_event = Buffer.capture() if self.async_finish else None
(
num_tokens_per_rank,
num_tokens_per_rdma_rank,
......@@ -247,8 +250,8 @@ class DeepEPDispatcher:
topk_idx,
num_experts,
previous_event=previous_event,
async_finish=False,
allocate_on_comm_stream=False,
async_finish=self.async_finish,
allocate_on_comm_stream=previous_event is not None,
)
(
......@@ -267,8 +270,8 @@ class DeepEPDispatcher:
is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=num_tokens_per_expert,
previous_event=previous_event,
async_finish=False,
allocate_on_comm_stream=False,
async_finish=self.async_finish,
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
)
return (
......@@ -333,7 +336,7 @@ class DeepEPDispatcher:
topk_idx,
num_max_dispatch_tokens_per_rank,
num_experts,
async_finish=False,
async_finish=self.async_finish,
return_recv_hook=False, # True for double-batch overlapping, need call hook()
)
)
......@@ -373,16 +376,22 @@ class DeepEPDispatcher:
hidden_states, event, hook = self.combine_low_latency(
hidden_states, self.topk_idx, self.topk_weights, self.handle
)
if self.async_finish:
event.current_stream_wait()
self.handle = None
return hidden_states
def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
def combine_normal(self, x: torch.Tensor, handle: Tuple):
previous_event = Buffer.capture() if self.async_finish else None
combined_x, _, event = self.buffer_normal.combine(
x,
handle,
async_finish=False,
async_finish=self.async_finish,
previous_event=previous_event,
allocate_on_comm_stream=False,
allocate_on_comm_stream=previous_event is not None,
)
return combined_x, event
......@@ -399,7 +408,7 @@ class DeepEPDispatcher:
topk_idx,
topk_weights,
handle,
async_finish=False,
async_finish=self.async_finish,
return_recv_hook=False, # True for double-batch overlapping, need call hook()
)
)
......
......@@ -239,6 +239,7 @@ class DeepseekV2MoE(nn.Module):
num_local_experts=config.n_routed_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
async_finish=True, # TODO
)
def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment