Unverified Commit 77e929a1 authored by fzyzcjy, committed by GitHub

Support async DeepEP by splitting into two stages (#4995)

parent febe21ce
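This change splits each blocking dispatcher entry point into two stages: the `_a` stage issues the DeepEP communication (or captures a CUDA event for it) and returns intermediate state, and the `_b` stage waits for completion and finishes post-processing. The original blocking `dispatch`/`combine` remain as thin wrappers that run both stages back to back, so existing callers keep working while new callers can overlap other work between the two halves.

A minimal usage sketch (not part of this patch; `dispatcher` is an already-constructed DeepEPDispatcher, and the inputs, `other_compute`, and `run_experts` are placeholders):

    dispatcher.dispatch_a(
        hidden_states=hidden_states,
        topk_idx=topk_idx,
        topk_weights=topk_weights,
        num_experts=num_experts,
        forward_mode=forward_mode,
    )
    other_compute()  # runs while the dispatch communication is in flight
    dispatch_output = dispatcher.dispatch_b()

    expert_output = run_experts(dispatch_output)  # placeholder expert MLPs

    dispatcher.combine_a(
        hidden_states=expert_output,
        topk_idx=topk_idx,
        topk_weights=topk_weights,
        forward_mode=forward_mode,
    )
    other_compute()  # runs while the combine communication is in flight
    combined = dispatcher.combine_b()

Note that the wrapper stashes a single `_dispatch_intermediate_state` (and deletes it in `dispatch_b`), so at most one dispatch and one combine can be in flight per dispatcher at a time.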
@@ -113,7 +113,7 @@ class _DeepEPDispatcherImplBase:
         self.handle = None
 
-    def dispatch(
+    def dispatch_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -123,12 +123,18 @@ class _DeepEPDispatcherImplBase:
     ):
         raise NotImplementedError
 
-    def combine(
+    def dispatch_b(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def combine_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-    ) -> torch.Tensor:
+    ):
+        raise NotImplementedError
+
+    def combine_b(self, *args, **kwargs):
         raise NotImplementedError
@@ -142,7 +148,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         self.async_finish = async_finish
         self.src2dst = None
 
-    def dispatch(
+    def dispatch_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -151,12 +157,20 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         num_max_dispatch_tokens_per_rank: int,
     ):
         topk_idx = topk_idx.to(torch.int64)
+        previous_event = Buffer.capture() if self.async_finish else None
+        return hidden_states, topk_idx, topk_weights, num_experts, previous_event
+
+    def dispatch_b(
+        self, hidden_states, topk_idx, topk_weights, num_experts, previous_event
+    ):
         (
             hidden_states,
             topk_idx,
             topk_weights,
             event,
-        ) = self._dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
+        ) = self._dispatch_core(
+            hidden_states, topk_idx, topk_weights, num_experts, previous_event
+        )
         event.current_stream_wait() if self.async_finish else ()
         if hidden_states.shape[0] > 0:
             reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
@@ -187,15 +201,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             expected_m,
         )
 
-    def _dispatch_normal(
+    def _dispatch_core(
         self,
         x: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         num_experts: int,
+        previous_event,
     ):
-        previous_event = Buffer.capture() if self.async_finish else None
-
         (
             num_tokens_per_rank,
             num_tokens_per_rdma_rank,
@@ -279,12 +292,12 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             )
         return reorder_topk_ids, seg_indptr, gateup_input
 
-    def combine(
+    def combine_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-    ) -> torch.Tensor:
+    ):
         if hidden_states.shape[0] > 0:
             num_tokens = self.src2dst.shape[0] // self.router_topk
             output = torch.empty(
@@ -308,16 +321,15 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
                 device=hidden_states.device,
                 dtype=hidden_states.dtype,
             )
-        hidden_states, event = self._combine_normal(
-            output,
-        )
-        event.current_stream_wait() if self.async_finish else ()
+        previous_event = Buffer.capture() if self.async_finish else None
+        return output, previous_event
+
+    def combine_b(self, output, previous_event):
+        hidden_states, event = self._combine_core(output, previous_event)
+        event.current_stream_wait() if self.async_finish else ()
         return hidden_states
 
-    def _combine_normal(self, x: torch.Tensor):
-        previous_event = Buffer.capture() if self.async_finish else None
-
+    def _combine_core(self, x: torch.Tensor, previous_event):
         combined_x, _, event = self.buffer_normal.combine(
             x,
             self.handle,
@@ -346,7 +358,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         )
         self.return_recv_hook = return_recv_hook
 
-    def dispatch(
+    def dispatch_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -361,13 +373,33 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
                 * topk_idx.shape[1]
                 + num_experts
             ) // num_experts
-        hidden_states, masked_m, event, hook = self._dispatch_low_latency(
+        hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
             topk_idx,
             num_max_dispatch_tokens_per_rank,
             num_experts,
             use_fp8=True,
         )
+        return (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            masked_m,
+            expected_m,
+            event,
+            hook,
+        )
+
+    def dispatch_b(
+        self,
+        hidden_states,
+        topk_idx,
+        topk_weights,
+        masked_m,
+        expected_m,
+        event,
+        hook,
+    ):
         hook() if self.return_recv_hook else event.current_stream_wait()
 
         # TODO
@@ -389,7 +421,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             expected_m,
         )
 
-    def _dispatch_low_latency(
+    def _dispatch_core(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -443,22 +475,24 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         )
         return packed_recv_hidden, packed_recv_count, event, hook
 
-    def combine(
+    def combine_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-    ) -> torch.Tensor:
-        hidden_states, event, hook = self._combine_low_latency(
+    ):
+        hidden_states, event, hook = self._combine_core(
            hidden_states,
            topk_idx,
            topk_weights,
        )
-        hook() if self.return_recv_hook else event.current_stream_wait()
+        return hidden_states, event, hook
+
+    def combine_b(self, hidden_states, event, hook):
+        hook() if self.return_recv_hook else event.current_stream_wait()
         return hidden_states
 
-    def _combine_low_latency(
+    def _combine_core(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -514,7 +548,11 @@ class DeepEPDispatcher:
             **common_kwargs,
         )
 
-    def dispatch(
+    def dispatch(self, *args, **kwargs) -> Tuple:
+        self.dispatch_a(*args, **kwargs)
+        return self.dispatch_b()
+
+    def dispatch_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -522,29 +560,45 @@ class DeepEPDispatcher:
         num_experts: int,
         num_max_dispatch_tokens_per_rank: int = 128,
         forward_mode: ForwardMode = None,
-    ) -> Tuple:
-        return self._get_dispatcher(forward_mode).dispatch(
+    ):
+        inner_state = self._get_impl(forward_mode).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             num_experts=num_experts,
             num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
         )
+        self._dispatch_intermediate_state = forward_mode, inner_state
+
+    def dispatch_b(self):
+        forward_mode, inner_state = self._dispatch_intermediate_state
+        del self._dispatch_intermediate_state
+        return self._get_impl(forward_mode).dispatch_b(*inner_state)
+
+    def combine(self, *args, **kwargs) -> Tuple:
+        self.combine_a(*args, **kwargs)
+        return self.combine_b()
 
-    def combine(
+    def combine_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         forward_mode: ForwardMode,
-    ) -> torch.Tensor:
-        return self._get_dispatcher(forward_mode).combine(
+    ):
+        inner_state = self._get_impl(forward_mode).combine_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
+        self._combine_intermediate_state = forward_mode, inner_state
+
+    def combine_b(self):
+        forward_mode, inner_state = self._combine_intermediate_state
+        del self._combine_intermediate_state
+        return self._get_impl(forward_mode).combine_b(*inner_state)
 
-    def _get_dispatcher(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
+    def _get_impl(self, forward_mode: ForwardMode) -> "_DeepEPDispatcherImplBase":
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
             return self._normal_dispatcher
...
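For context on the synchronization plumbing in this patch: in normal mode, `Buffer.capture()` in the `_a` stage snapshots the current CUDA stream as `previous_event`, which is threaded into `_dispatch_core`/`_combine_core` so DeepEP can order the communication after all previously issued work; `event.current_stream_wait()` in the `_b` stage then orders subsequent kernels after the communication. In low-latency mode the `_b`-stage wait is either a receive-hook call or the same event wait. A rough, self-contained sketch of this event-chaining idea using plain torch.cuda events (placeholder kernel, requires a GPU; this is not DeepEP's actual implementation):

    import torch

    def stage_a(x: torch.Tensor):
        # Snapshot the compute stream, analogous to Buffer.capture().
        previous_event = torch.cuda.Event()
        previous_event.record()
        return x, previous_event

    def stage_b(x: torch.Tensor, previous_event: torch.cuda.Event) -> torch.Tensor:
        comm_stream = torch.cuda.Stream()
        with torch.cuda.stream(comm_stream):
            # Order the "communication" after everything captured in stage_a.
            comm_stream.wait_event(previous_event)
            y = x * 2  # placeholder for the real all-to-all kernel
            done = torch.cuda.Event()
            done.record()
        # Analogous to event.current_stream_wait(): later kernels on the
        # compute stream wait for the communication, with no host-side sync.
        torch.cuda.current_stream().wait_event(done)
        return y

    x = torch.ones(4, device="cuda")
    x, ev = stage_a(x)
    # ...independent work can be issued here while the side stream runs...
    print(stage_b(x, ev))  # tensor([2., 2., 2., 2.], device='cuda:0')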