Commit fff6432b authored by laibao's avatar laibao
Browse files

fix: feed the input and a duplicate of the input into the experts and the shared_experts, respectively.

parent 0019ecdc
...@@ -1268,7 +1268,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1268,7 +1268,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
("true", "1")), ("true", "1")),
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool( "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0")) int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "1"))
), ),
} }
......
...@@ -1483,6 +1483,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1483,6 +1483,7 @@ class FusedMoE(torch.nn.Module):
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None, i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_ i_s: Optional[torch.Tensor] = None, **_
...@@ -1500,9 +1501,9 @@ class FusedMoE(torch.nn.Module): ...@@ -1500,9 +1501,9 @@ class FusedMoE(torch.nn.Module):
else: else:
if current_platform.is_tpu(): if current_platform.is_tpu():
assert i_q is None and i_s is None, "moe.quant fused not support TPU now" assert i_q is None and i_s is None, "moe.quant fused not support TPU now"
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, hidden_states_copy, router_logits)
else: else:
return torch.ops.vllm.moe_forward_shared(hidden_states, router_logits, self.layer_name) return torch.ops.vllm.moe_forward_shared(hidden_states, router_logits, hidden_states_copy, self.layer_name)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor, def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor): full_router_logits: torch.Tensor):
...@@ -1623,11 +1624,14 @@ class FusedMoE(torch.nn.Module): ...@@ -1623,11 +1624,14 @@ class FusedMoE(torch.nn.Module):
def forward_impl(self, hidden_states: torch.Tensor, def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None, i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: i_s: Optional[torch.Tensor] = None, **_) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None assert self.quant_method is not None
if self.shared_experts_stream is not None and hidden_states_copy is not None:
hidden_states_copy.record_stream(self.shared_experts_stream)
use_shared_experts_stream, hidden_states_clone = self._maybe_setup_shared_experts_stream(hidden_states, use_shared_experts_stream, hidden_states_clone = self._maybe_setup_shared_experts_stream(hidden_states,
self.shared_experts is not None and self.shared_experts_stream is not None, self.shared_experts is not None and self.shared_experts_stream is not None,
self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels) self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels)
...@@ -1707,7 +1711,8 @@ class FusedMoE(torch.nn.Module): ...@@ -1707,7 +1711,8 @@ class FusedMoE(torch.nn.Module):
with torch.cuda.stream(self.shared_experts_stream): with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid # Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream # conflict with the main stream
shared_output = self.shared_experts(hidden_states) assert hidden_states_copy is not None
shared_output = self.shared_experts(hidden_states_copy)
torch.cuda.current_stream().wait_stream(self.shared_experts_stream) torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
final_hidden_states = ( final_hidden_states = (
...@@ -1822,18 +1827,20 @@ direct_register_custom_op( ...@@ -1822,18 +1827,20 @@ direct_register_custom_op(
def moe_forward_shared( def moe_forward_shared(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
hidden_states_copy: torch.Tensor,
layer_name: str) -> tuple[torch.Tensor, torch.Tensor]: layer_name: str) -> tuple[torch.Tensor, torch.Tensor]:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None assert self.quant_method is not None
out = self.forward_impl(hidden_states, router_logits) out = self.forward_impl(hidden_states, router_logits, hidden_states_copy)
return out return out
def moe_forward_shared_fake( def moe_forward_shared_fake(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
hidden_states_copy: torch.Tensor,
layer_name: str) -> tuple[torch.Tensor, torch.Tensor]: layer_name: str) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states) shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states) fused_out = torch.empty_like(hidden_states)
......
...@@ -42,12 +42,14 @@ class SharedFusedMoE(FusedMoE): ...@@ -42,12 +42,14 @@ class SharedFusedMoE(FusedMoE):
def forward(self, def forward(self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor router_logits: torch.Tensor,
hidden_states_copy: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
shared_out, fused_out = super().forward( shared_out, fused_out = super().forward(
hidden_states=hidden_states, hidden_states = hidden_states,
router_logits=router_logits, router_logits = router_logits,
hidden_states_copy = hidden_states_copy
) )
# # ensure early TP reduction of shared expert outputs when required # # ensure early TP reduction of shared expert outputs when required
# if ( # if (
......
...@@ -244,7 +244,11 @@ class DeepseekV2MoE(nn.Module): ...@@ -244,7 +244,11 @@ class DeepseekV2MoE(nn.Module):
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
shared_output, final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) hidden_states_copy = hidden_states.clone()
shared_output, final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
hidden_states_copy = hidden_states_copy)
if self.shared_experts is None: if self.shared_experts is None:
assert shared_output is None assert shared_output is None
...@@ -261,15 +265,6 @@ class DeepseekV2MoE(nn.Module): ...@@ -261,15 +265,6 @@ class DeepseekV2MoE(nn.Module):
assert shared_output is not None assert shared_output is not None
final_hidden_states += shared_output final_hidden_states += shared_output
# if self.is_sequence_parallel:
# final_hidden_states = tensor_model_parallel_all_gather(
# final_hidden_states, 0
# )
# final_hidden_states = final_hidden_states[:num_tokens]
# elif self.tp_size > 1:
# final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
# final_hidden_states
# )
if self.tp_size > 1: if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO: if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states) final_hidden_states = self.tbo_all_reduce(final_hidden_states)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment