Unverified Commit 2196b766 authored by Arturo Ghinassi's avatar Arturo Ghinassi Committed by GitHub
Browse files

convert input to same type as weight for mixed precision training

When using CUDA AMP FMoE Linear throws type error as input is half() and weights are float()
parent df715c9f
......@@ -70,7 +70,7 @@ class FMoELinear(nn.Module):
r"""
Call MOE function
"""
x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
x = MOELinear.apply(inp.float(), fwd_expert_count, self.weight, self.bias)
return x
def extra_repr(self) -> str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment