Unverified Commit 6936be32 authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Remve router gemm output dtype conversion (#8204)

parent 9b5de6cb
...@@ -254,9 +254,8 @@ class MoEGate(nn.Module): ...@@ -254,9 +254,8 @@ class MoEGate(nn.Module):
and self.weight.shape[0] == 256 and self.weight.shape[0] == 256
and _device_sm >= 90 and _device_sm >= 90
): ):
logits = dsv3_router_gemm(hidden_states, self.weight).to( # router gemm output float32
hidden_states.dtype logits = dsv3_router_gemm(hidden_states, self.weight)
)
else: else:
logits = F.linear(hidden_states, self.weight, None) logits = F.linear(hidden_states, self.weight, None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment