Unverified commit 9d2d5612, authored by 杰兮, committed by GitHub
Browse files

[Bugfix] Fix precision corruption when shared_experts_stream=None (#28942)


Signed-off-by: zhyajie <yajizhan@amd.com>
Co-authored-by: zhyajie <yajizhan@amd.com>
parent fe69f331
...@@ -371,8 +371,8 @@ class FusedMoE(CustomOp): ...@@ -371,8 +371,8 @@ class FusedMoE(CustomOp):
logger.info_once("Disabling MoE shared_experts cuda stream") logger.info_once("Disabling MoE shared_experts cuda stream")
self.shared_experts_stream = None self.shared_experts_stream = None
else: else:
# TODO(rob): enable shared expert overlap with non-cuda. # TODO(rob): enable shared expert overlap with non-cuda-alike.
# aux_stream() returns None on non-cuda platforms. # aux_stream() returns None on non-cuda-alike platforms.
self.shared_experts_stream = aux_stream() self.shared_experts_stream = aux_stream()
if self.shared_experts_stream is not None: if self.shared_experts_stream is not None:
logger.info_once("Enabled separate cuda stream for MoE shared_experts") logger.info_once("Enabled separate cuda stream for MoE shared_experts")
...@@ -1865,6 +1865,11 @@ class FusedMoE(CustomOp): ...@@ -1865,6 +1865,11 @@ class FusedMoE(CustomOp):
hidden_states_combined, router_logits = get_ep_group().dispatch( hidden_states_combined, router_logits = get_ep_group().dispatch(
hidden_states, router_logits, self.is_sequence_parallel hidden_states, router_logits, self.is_sequence_parallel
) )
# Run shared experts before the matrix multiply,
# because the matrix multiply may modify hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream:
assert self.shared_experts is not None
shared_output = self.shared_experts(hidden_states)
# Matrix multiply. # Matrix multiply.
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(
...@@ -1908,8 +1913,6 @@ class FusedMoE(CustomOp): ...@@ -1908,8 +1913,6 @@ class FusedMoE(CustomOp):
# conflict with the main stream # conflict with the main stream
shared_output = self.shared_experts(hidden_states_clone) shared_output = self.shared_experts(hidden_states_clone)
current_stream().wait_stream(self.shared_experts_stream) current_stream().wait_stream(self.shared_experts_stream)
else:
shared_output = self.shared_experts(hidden_states)
final_hidden_states = ( final_hidden_states = (
shared_output, shared_output,
......
...@@ -426,8 +426,7 @@ def aux_stream() -> torch.cuda.Stream | None: ...@@ -426,8 +426,7 @@ def aux_stream() -> torch.cuda.Stream | None:
from vllm.platforms import current_platform from vllm.platforms import current_platform
# TODO: validate this works properly on ROCm platform. if _aux_stream is None and current_platform.is_cuda_alike():
if _aux_stream is None and current_platform.is_cuda():
_aux_stream = torch.cuda.Stream() _aux_stream = torch.cuda.Stream()
return _aux_stream return _aux_stream
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment