"tests/vscode:/vscode.git/clone" did not exist on "c8b3b299c9f3142546e0a41f835e561af1aaffb7"
Unverified Commit 4685a630 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[Model Bash][DeepSeekR1] Remove Shared Expert Clone (#34344)


Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
parent ee1d25f1
...@@ -240,24 +240,22 @@ class DefaultMoERunner(MoERunner): ...@@ -240,24 +240,22 @@ class DefaultMoERunner(MoERunner):
) )
) )
hidden_states_clone: torch.Tensor | None = None shared_experts_input: torch.Tensor | None = None
if use_shared_experts_stream: if use_shared_experts_stream:
assert self.shared_experts_stream is not None assert self.shared_experts_stream is not None
assert self.moe_config.disable_inplace
shared_experts_input = ( shared_experts_input = (
shared_input if shared_input is not None else hidden_states shared_input if shared_input is not None else hidden_states
) )
# Clone BEFORE switching streams to avoid race condition # Record that the shared_experts_input will be used in the
# where routed_expert kernel may mutate hidden_states. # shared_experts_stream to to avoid gc issue from
hidden_states_clone = shared_experts_input.clone() # deallocation. For more details:
# https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: We don't need shared_output.record_stream(current_stream()) # NOTE: We don't need shared_output.record_stream(current_stream())
# because we synch the streams before using shared_output. # because we synch the streams before using shared_output.
hidden_states_clone.record_stream(self.shared_experts_stream) shared_experts_input.record_stream(self.shared_experts_stream)
# Mark sync start point for the separate shared experts # Mark sync start point for the separate shared experts
# stream here since we want to run in parallel with the # stream here since we want to run in parallel with the
...@@ -265,7 +263,7 @@ class DefaultMoERunner(MoERunner): ...@@ -265,7 +263,7 @@ class DefaultMoERunner(MoERunner):
assert self.shared_experts_stream is not None assert self.shared_experts_stream is not None
self.shared_experts_stream.wait_stream(current_stream()) self.shared_experts_stream.wait_stream(current_stream())
return use_shared_experts_stream, hidden_states_clone return use_shared_experts_stream, shared_experts_input
def ensure_dp_chunking_init(self): def ensure_dp_chunking_init(self):
if not self.use_dp_chunking or self.batched_hidden_states is not None: if not self.use_dp_chunking or self.batched_hidden_states is not None:
...@@ -584,7 +582,7 @@ class DefaultMoERunner(MoERunner): ...@@ -584,7 +582,7 @@ class DefaultMoERunner(MoERunner):
use_chunked_impl = self.use_dp_chunking use_chunked_impl = self.use_dp_chunking
use_shared_experts_stream, hidden_states_clone = ( use_shared_experts_stream, shared_experts_input = (
self._maybe_setup_shared_experts_stream( self._maybe_setup_shared_experts_stream(
hidden_states, hidden_states,
shared_input, shared_input,
...@@ -726,7 +724,7 @@ class DefaultMoERunner(MoERunner): ...@@ -726,7 +724,7 @@ class DefaultMoERunner(MoERunner):
with torch.cuda.stream(self.shared_experts_stream): with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid # Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream # conflict with the main stream
shared_output = self.shared_experts(hidden_states_clone) shared_output = self.shared_experts(shared_experts_input)
current_stream().wait_stream(self.shared_experts_stream) current_stream().wait_stream(self.shared_experts_stream)
final_hidden_states = ( final_hidden_states = (
......
...@@ -175,7 +175,7 @@ class MiniCPMMoE(nn.Module): ...@@ -175,7 +175,7 @@ class MiniCPMMoE(nn.Module):
) )
final_hidden_states = fused_experts( final_hidden_states = fused_experts(
hidden_states, self.ws, self.w2s, topk_weights, topk_ids, inplace=True hidden_states, self.ws, self.w2s, topk_weights, topk_ids, inplace=False
) )
if self.tp_size > 1: if self.tp_size > 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment