Commit bda70b4b authored by puck_WCR's avatar puck_WCR Committed by binmakeswell
Browse files

[NFC] polish colossalai/kernel/cuda_native/layer_norm.py code style (#980)

parent c50c08dc
...@@ -24,8 +24,8 @@ class FusedLayerNormAffineFunction(torch.autograd.Function): ...@@ -24,8 +24,8 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
input_ = input.contiguous() input_ = input.contiguous()
weight_ = weight.contiguous() weight_ = weight.contiguous()
bias_ = bias.contiguous() bias_ = bias.contiguous()
output, mean, invvar = colossal_layer_norm_cuda.forward_affine( output, mean, invvar = colossal_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, bias_,
input_, ctx.normalized_shape, weight_, bias_, ctx.eps) ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar) ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output return output
...@@ -72,8 +72,7 @@ class MixedFusedLayerNorm(torch.nn.Module): ...@@ -72,8 +72,7 @@ class MixedFusedLayerNorm(torch.nn.Module):
def forward(self, input): def forward(self, input):
return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias, return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)
self.normalized_shape, self.eps)
def __repr__(self): def __repr__(self):
return f'MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})' return f'MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment