Unverified Commit 4548c03c authored by Chengji Yao's avatar Chengji Yao Committed by GitHub
Browse files

[TPU][Bugfix] fix the MoE OOM issue (#20339)


Signed-off-by: default avatarChengji Yao <chengjiyao@google.com>
parent 40b86aa0
...@@ -1320,6 +1320,11 @@ class FusedMoE(torch.nn.Module): ...@@ -1320,6 +1320,11 @@ class FusedMoE(torch.nn.Module):
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor):
# TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op.
if current_platform.is_tpu():
return self.forward_impl(hidden_states, router_logits)
else:
return torch.ops.vllm.moe_forward(hidden_states, router_logits, return torch.ops.vllm.moe_forward(hidden_states, router_logits,
self.layer_name) self.layer_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment