Commit fff6432b authored by laibao
Browse files

fix: feed the input to the routed experts and a separate copy of the input to the shared experts, respectively.

parent 0019ecdc
......@@ -1268,7 +1268,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
("true", "1")),
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0"))
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "1"))
),
}
......
......@@ -1483,6 +1483,7 @@ class FusedMoE(torch.nn.Module):
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor,
hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
......@@ -1500,9 +1501,9 @@ class FusedMoE(torch.nn.Module):
else:
if current_platform.is_tpu():
assert i_q is None and i_s is None, "moe.quant fused not support TPU now"
return self.forward_impl(hidden_states, router_logits)
return self.forward_impl(hidden_states, hidden_states_copy, router_logits)
else:
return torch.ops.vllm.moe_forward_shared(hidden_states, router_logits, self.layer_name)
return torch.ops.vllm.moe_forward_shared(hidden_states, router_logits, hidden_states_copy, self.layer_name)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor):
......@@ -1623,11 +1624,14 @@ class FusedMoE(torch.nn.Module):
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor,
hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None
if self.shared_experts_stream is not None and hidden_states_copy is not None:
hidden_states_copy.record_stream(self.shared_experts_stream)
use_shared_experts_stream, hidden_states_clone = self._maybe_setup_shared_experts_stream(hidden_states,
self.shared_experts is not None and self.shared_experts_stream is not None,
self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels)
......@@ -1707,7 +1711,8 @@ class FusedMoE(torch.nn.Module):
with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
shared_output = self.shared_experts(hidden_states)
assert hidden_states_copy is not None
shared_output = self.shared_experts(hidden_states_copy)
torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
final_hidden_states = (
......@@ -1822,11 +1827,12 @@ direct_register_custom_op(
def moe_forward_shared(
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        hidden_states_copy: torch.Tensor,
        layer_name: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the MoE layer registered under ``layer_name`` on the given inputs.

    The layer is looked up by name in the current forward context's
    ``no_compile_layers`` mapping, and its ``forward_impl`` is invoked
    exactly once.

    Args:
        hidden_states: Input activations forwarded to ``forward_impl``.
        router_logits: Router logits forwarded to ``forward_impl``.
        hidden_states_copy: Separate copy of the input, forwarded as the
            third positional argument of ``forward_impl``.
        layer_name: Key of the target layer in ``no_compile_layers``.

    Returns:
        Whatever ``forward_impl`` returns for this layer; annotated as a
        pair of tensors.
    """
    forward_context: ForwardContext = get_forward_context()
    layer = forward_context.no_compile_layers[layer_name]
    assert layer.quant_method is not None
    # Argument order matches forward_impl(hidden_states, router_logits,
    # hidden_states_copy, ...); the copy must be passed through so that the
    # layer is not run without it.
    return layer.forward_impl(hidden_states, router_logits, hidden_states_copy)
......@@ -1834,6 +1840,7 @@ def moe_forward_shared(
def moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
hidden_states_copy: torch.Tensor,
layer_name: str) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states)
......
......@@ -42,12 +42,14 @@ class SharedFusedMoE(FusedMoE):
def forward(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor
router_logits: torch.Tensor,
hidden_states_copy: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out, fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
hidden_states = hidden_states,
router_logits = router_logits,
hidden_states_copy = hidden_states_copy
)
# # ensure early TP reduction of shared expert outputs when required
# if (
......
......@@ -244,7 +244,11 @@ class DeepseekV2MoE(nn.Module):
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
router_logits, _ = self.gate(hidden_states)
shared_output, final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits)
hidden_states_copy = hidden_states.clone()
shared_output, final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
hidden_states_copy = hidden_states_copy)
if self.shared_experts is None:
assert shared_output is None
......@@ -261,15 +265,6 @@ class DeepseekV2MoE(nn.Module):
assert shared_output is not None
final_hidden_states += shared_output
# if self.is_sequence_parallel:
# final_hidden_states = tensor_model_parallel_all_gather(
# final_hidden_states, 0
# )
# final_hidden_states = final_hidden_states[:num_tokens]
# elif self.tp_size > 1:
# final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
# final_hidden_states
# )
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment