Unverified Commit 6ff9c6a5 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Cleanup unused resources after DeepEP operation (#4996)

parent 77e929a1
......@@ -184,11 +184,6 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
(num_experts + 1,), device=hidden_states.device, dtype=torch.int64
)
# TODO
# masked_m = torch.empty(
# (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
# )
# expected_m = 0
masked_m = expected_m = None
return (
......@@ -327,6 +322,8 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
def combine_b(self, output, previous_event):
hidden_states, event = self._combine_core(output, previous_event)
event.current_stream_wait() if self.async_finish else ()
self.handle = None
self.src2dst = None
return hidden_states
def _combine_core(self, x: torch.Tensor, previous_event):
......@@ -402,13 +399,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
):
hook() if self.return_recv_hook else event.current_stream_wait()
# TODO
# reorder_topk_ids = torch.empty(
# (0,), device=hidden_states.device, dtype=torch.int64
# )
# seg_indptr = torch.zeros(
# (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
# )
reorder_topk_ids = seg_indptr = None
return (
......@@ -508,6 +498,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
return_recv_hook=self.return_recv_hook,
)
)
self.handle = None
return combined_hidden_states, event, hook
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment