Unverified commit d63bdf78 authored by Jim Briggs, committed by GitHub

Add FP32 cast in ConvNext LayerNorm to prevent rounding errors with FP16 input (#18746)

* Add a cast to FP32 in the ConvNext LayerNorm to prevent rounding errors with FP16 input

* Trigger CI
parent 532ca050
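
For context, a quick sketch (not part of the commit; the values are hypothetical) of why half precision struggles with these statistics: FP16 carries a 10-bit mantissa and tops out at 65504, so the per-channel mean and variance computed in the patched branch below can both lose low-order bits and overflow outright once activations reach the hundreds:

import torch

# FP16 spacing near 2048 is 2.0, so adjacent integers collapse together
print(torch.tensor(2049.0, dtype=torch.float16).item())  # -> 2048.0

# Squared deviations above sqrt(65504) ~ 256 overflow to inf in FP16,
# which then poisons the variance and the normalized output
x = torch.tensor([300.0, -300.0], dtype=torch.float16)
u = x.mean()
print((x - u).pow(2))                  # tensor([inf, inf], dtype=torch.float16)
print((x.float() - u.float()).pow(2))  # tensor([90000., 90000.]) in FP32
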
@@ -109,9 +109,12 @@ class ConvNextLayerNorm(nn.Module):
         if self.data_format == "channels_last":
             x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         elif self.data_format == "channels_first":
+            input_dtype = x.dtype
+            x = x.float()
             u = x.mean(1, keepdim=True)
             s = (x - u).pow(2).mean(1, keepdim=True)
             x = (x - u) / torch.sqrt(s + self.eps)
+            x = x.to(dtype=input_dtype)
             x = self.weight[:, None, None] * x + self.bias[:, None, None]
         return x
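
To see the fix end to end, here is a minimal standalone sketch of the patched channels-first branch (the function name and test tensor are assumptions for illustration, not code from the repository), comparing the output computed with and without the FP32 round trip against an FP64 reference:

import torch

def layer_norm_channels_first(x: torch.Tensor, eps: float = 1e-6, upcast: bool = True) -> torch.Tensor:
    # Hypothetical standalone re-implementation of the patched branch above,
    # normalizing over the channel dim (dim 1); not the repository's class
    input_dtype = x.dtype
    if upcast:
        x = x.float()  # compute the statistics in FP32, as the commit does
    u = x.mean(1, keepdim=True)
    s = (x - u).pow(2).mean(1, keepdim=True)
    x = (x - u) / torch.sqrt(s + eps)
    return x.to(dtype=input_dtype)

torch.manual_seed(0)
x = (torch.randn(2, 96, 14, 14) * 150).half()  # large activations stress FP16
x[0, 0, 0, 0] = 600.0  # (600 - u)^2 > 65504 guarantees an FP16 overflow

ref = layer_norm_channels_first(x.double(), upcast=False)  # FP64 reference
err_no_cast = (layer_norm_channels_first(x, upcast=False).double() - ref).abs().max().item()
err_cast = (layer_norm_channels_first(x, upcast=True).double() - ref).abs().max().item()
print(f"max abs error without the cast: {err_no_cast:.1e}")
print(f"max abs error with the cast:    {err_cast:.1e}")

Without the cast, the squared deviation at the planted outlier overflows to inf, the variance at that spatial position becomes inf, and the whole channel column normalizes to zero; with the cast, the error drops to ordinary FP16 output rounding.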