"examples/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "18601691c2e83adcdf8e3e6bd88db41b0fd2fc2e"
Unverified Commit 6ff9c6a5 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Cleanup unused resources after DeepEP operation (#4996)

parent 77e929a1
...@@ -184,11 +184,6 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): ...@@ -184,11 +184,6 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
(num_experts + 1,), device=hidden_states.device, dtype=torch.int64 (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
) )
# TODO
# masked_m = torch.empty(
# (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
# )
# expected_m = 0
masked_m = expected_m = None masked_m = expected_m = None
return ( return (
...@@ -327,6 +322,8 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): ...@@ -327,6 +322,8 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
def combine_b(self, output, previous_event): def combine_b(self, output, previous_event):
hidden_states, event = self._combine_core(output, previous_event) hidden_states, event = self._combine_core(output, previous_event)
event.current_stream_wait() if self.async_finish else () event.current_stream_wait() if self.async_finish else ()
self.handle = None
self.src2dst = None
return hidden_states return hidden_states
def _combine_core(self, x: torch.Tensor, previous_event): def _combine_core(self, x: torch.Tensor, previous_event):
...@@ -402,13 +399,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): ...@@ -402,13 +399,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
): ):
hook() if self.return_recv_hook else event.current_stream_wait() hook() if self.return_recv_hook else event.current_stream_wait()
# TODO
# reorder_topk_ids = torch.empty(
# (0,), device=hidden_states.device, dtype=torch.int64
# )
# seg_indptr = torch.zeros(
# (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
# )
reorder_topk_ids = seg_indptr = None reorder_topk_ids = seg_indptr = None
return ( return (
...@@ -508,6 +498,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): ...@@ -508,6 +498,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
return_recv_hook=self.return_recv_hook, return_recv_hook=self.return_recv_hook,
) )
) )
self.handle = None
return combined_hidden_states, event, hook return combined_hidden_states, event, hook
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment