Unverified Commit 4548c03c authored by Chengji Yao's avatar Chengji Yao Committed by GitHub
Browse files

[TPU][Bugfix] fix the MoE OOM issue (#20339)


Signed-off-by: default avatarChengji Yao <chengjiyao@google.com>
parent 40b86aa0
......@@ -1320,8 +1320,13 @@ class FusedMoE(torch.nn.Module):
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
self.layer_name)
# TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op.
if current_platform.is_tpu():
return self.forward_impl(hidden_states, router_logits)
else:
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
self.layer_name)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment