Unverified Commit e605e8e3 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[Bugfix] Fix Stream Sync for Shared Expert Overlap (#28430)


Signed-off-by: default avatarVadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: default avatarRobert Shaw <robertgshaw2@gmail.com>
Co-authored-by: default avatarVadim Gimpelson <vadim.gimpelson@gmail.com>
parent bca74e32
...@@ -3,6 +3,3 @@ accuracy_threshold: 0.45 ...@@ -3,6 +3,3 @@ accuracy_threshold: 0.45
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
max_model_len: 4096 max_model_len: 4096
# Duo stream incompatabilbe with this model: https://github.com/vllm-project/vllm/issues/28220
env:
VLLM_DISABLE_SHARED_EXPERTS_STREAM: "1"
...@@ -2456,28 +2456,6 @@ class FusedMoE(CustomOp): ...@@ -2456,28 +2456,6 @@ class FusedMoE(CustomOp):
staged_hidden_states.copy_(hidden_states, non_blocking=True) staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True) staged_router_logits.copy_(router_logits, non_blocking=True)
# If there are shared experts but we are not using a modular kernel,
# the shared experts must be called here
if has_separate_shared_experts:
assert self.shared_experts is not None
if self.shared_experts_stream is not None:
# For chunked, we start the shared experts stream here
# (Note that no concurrency with the router/gate)
self.shared_experts_stream.wait_stream(current_stream())
with torch.cuda.stream(self.shared_experts_stream):
# Note that staged_hidden_states clone() is necessary
# here to avoid conflict with the main stream
shared_output = self.shared_experts(
staged_hidden_states.clone()
)
else:
shared_output = self.shared_experts(staged_hidden_states)
else:
shared_output = None
# Matrix multiply. # Matrix multiply.
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(
layer=self, layer=self,
...@@ -2506,11 +2484,7 @@ class FusedMoE(CustomOp): ...@@ -2506,11 +2484,7 @@ class FusedMoE(CustomOp):
if has_separate_shared_experts: if has_separate_shared_experts:
assert not isinstance(final_hidden_states, tuple) assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None assert self.shared_experts is not None
shared_output = self.shared_experts(staged_hidden_states)
# Here we finish the shared experts stream
if self.shared_experts_stream is not None:
current_stream().wait_stream(self.shared_experts_stream)
final_hidden_states = ( final_hidden_states = (
shared_output, shared_output,
final_hidden_states, final_hidden_states,
...@@ -2619,11 +2593,22 @@ class FusedMoE(CustomOp): ...@@ -2619,11 +2593,22 @@ class FusedMoE(CustomOp):
assert self.shared_experts is not None assert self.shared_experts is not None
if self.shared_experts_stream is not None: if self.shared_experts_stream is not None:
# Clone BEFORE switching streams to avoid race condition
# where routed_expert kernel may mutate hidden_states.
hidden_states_clone = hidden_states.clone()
self.shared_experts_stream.wait_stream(current_stream())
# Run shared experts in parallel on a separate stream # Run shared experts in parallel on a separate stream
with torch.cuda.stream(self.shared_experts_stream): with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid shared_output = self.shared_experts(hidden_states_clone)
# conflict with the main stream
shared_output = self.shared_experts(hidden_states.clone()) # Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: we dont need shared_output.record_stream(current_stream())
# because we synch the streams before using shared_output.
hidden_states_clone.record_stream(self.shared_experts_stream)
else: else:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment