Commit 83602797 authored by Guolin Ke's avatar Guolin Ke
Browse files

typo

parent cf503760
......@@ -62,7 +62,7 @@ class LayerNorm(torch.nn.Module):
return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps)
def fused_layer_norm(input):
if input.is_cuda():
if input.is_cuda:
return FusedLayerNormFastFunction.apply(
input, self.weight, self.bias, self.normalized_shape, self.eps)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment