Unverified Commit d63bdf78 authored by Jim Briggs, committed by GitHub

Add FP32 cast in ConvNext LayerNorm to prevent rounding errors with FP16 input (#18746)

* Adding cast to fp32 in convnext layernorm to prevent rounding errors in the case of fp16 input

* Trigger CI
parent 532ca050
...
@@ -109,9 +109,12 @@ class ConvNextLayerNorm(nn.Module):
         if self.data_format == "channels_last":
             x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         elif self.data_format == "channels_first":
+            input_dtype = x.dtype
+            x = x.float()
             u = x.mean(1, keepdim=True)
             s = (x - u).pow(2).mean(1, keepdim=True)
             x = (x - u) / torch.sqrt(s + self.eps)
+            x = x.to(dtype=input_dtype)
             x = self.weight[:, None, None] * x + self.bias[:, None, None]
         return x
...
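
For reference, here is a minimal standalone sketch of the patched channels_first path. The class below mirrors the snippet in the diff but drops the data_format handling and the channels_last branch, and the input shape and offset in the quick check are illustrative values, not taken from the model:

import torch
import torch.nn as nn

class ConvNextLayerNorm(nn.Module):
    """Channels-first LayerNorm, matching the patched forward in the diff (sketch only)."""

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps

    def forward(self, x):
        # Compute mean/variance in fp32 so fp16 inputs do not lose precision,
        # then cast back to the caller's dtype before the affine transform.
        input_dtype = x.dtype
        x = x.float()
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = x.to(dtype=input_dtype)
        return self.weight[:, None, None] * x + self.bias[:, None, None]

if __name__ == "__main__":
    # Illustrative check: activations with a large constant offset are where
    # half-precision statistics tend to round badly.
    torch.manual_seed(0)
    x = torch.randn(2, 96, 14, 14) * 0.01 + 60.0
    norm = ConvNextLayerNorm(96)
    ref = norm(x)            # fp32 reference
    out = norm(x.half())     # fp16 input, statistics computed in fp32
    print((out.float() - ref).abs().max())

With the fp32 cast in place, the fp16 result stays close to the fp32 reference; computing the mean and variance directly in fp16 on such inputs would not.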