Commit f4ce5889 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix a possible type bug in layernorm

parent f913d977
...@@ -60,14 +60,14 @@ class LayerNorm(torch.nn.Module): ...@@ -60,14 +60,14 @@ class LayerNorm(torch.nn.Module):
self.reset_parameters() self.reset_parameters()
def torch_layer_norm(input): def torch_layer_norm(input):
return F.layer_norm( return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps) input, self.normalized_shape, self.weight.type(input.dtype), self.bias.type(input.dtype), self.eps)
def fused_layer_norm(input): def fused_layer_norm(input):
if input.is_cuda: if input.is_cuda:
return FusedLayerNormFastFunction.apply( return FusedLayerNormFastFunction.apply(
input, self.weight, self.bias, self.normalized_shape, self.eps) input, self.weight.type(input.dtype), self.bias.type(input.dtype), self.normalized_shape, self.eps)
else: else:
return F.layer_norm( return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps) input, self.normalized_shape, self.weight.type(input.dtype), self.bias.type(input.dtype), self.eps)
self.func = torch_layer_norm if (not HAS_LAYER_NORM or normalized_shape[0] not in FUSED_LAYER_NORM_SUPPORT_DIM) else fused_layer_norm self.func = torch_layer_norm if (not HAS_LAYER_NORM or normalized_shape[0] not in FUSED_LAYER_NORM_SUPPORT_DIM) else fused_layer_norm
def reset_parameters(self): def reset_parameters(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment