Commit 5a4a3b77 authored by Jiang Zhuo's avatar Jiang Zhuo Committed by Frank Lee
Browse files

fix format (#376)

parent ce886a90
...@@ -7,7 +7,7 @@ except: ...@@ -7,7 +7,7 @@ except:
class FusedLayerNormAffineFunction1D(torch.autograd.Function): class FusedLayerNormAffineFunction1D(torch.autograd.Function):
r""" r"""
Layernorm Layernorm
:param input: input maxtrix :param input: input maxtrix
...@@ -20,27 +20,26 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function): ...@@ -20,27 +20,26 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function):
:param eps: a value added to the denominator for numerical stability :param eps: a value added to the denominator for numerical stability
""" """
@staticmethod @staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps): def forward(ctx, input, weight, bias, normalized_shape, eps):
ctx.normalized_shape = normalized_shape ctx.normalized_shape = normalized_shape
ctx.eps = eps ctx.eps = eps
input_ = input.contiguous() input_ = input.contiguous()
weight_ = weight.contiguous() weight_ = weight.contiguous()
bias_ = bias.contiguous() bias_ = bias.contiguous()
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine( output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_,
input_, ctx.normalized_shape, weight_, bias_, ctx.eps) bias_, ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar) ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output return output
@staticmethod
@staticmethod def backward(ctx, grad_output):
def backward(ctx, grad_output): input_, weight_, bias_, mean, invvar = ctx.saved_tensors
input_, weight_, bias_, mean, invvar = ctx.saved_tensors grad_input = grad_weight = grad_bias = None
grad_input = grad_weight = grad_bias = None grad_input, grad_weight, grad_bias \
grad_input, grad_weight, grad_bias \ = fused_mix_prec_layer_norm_cuda.backward_affine(
= fused_mix_prec_layer_norm_cuda.backward_affine( grad_output.contiguous(), mean, invvar,
grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape,
input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None
return grad_input, grad_weight, grad_bias, None, None
\ No newline at end of file
...@@ -81,6 +81,7 @@ class _ReduceGrad(torch.autograd.Function): ...@@ -81,6 +81,7 @@ class _ReduceGrad(torch.autograd.Function):
:param input_: input matrix :param input_: input matrix
:param parallel_mode: parallel mode :param parallel_mode: parallel mode
""" """
@staticmethod @staticmethod
def symbolic(graph, input_): def symbolic(graph, input_):
return input_ return input_
...@@ -102,6 +103,7 @@ class _ReduceInput(torch.autograd.Function): ...@@ -102,6 +103,7 @@ class _ReduceInput(torch.autograd.Function):
:param input_: input matrix :param input_: input matrix
:param parallel_mode: parallel mode :param parallel_mode: parallel mode
""" """
@staticmethod @staticmethod
def symbolic(graph, input_): def symbolic(graph, input_):
return _reduce(input_) return _reduce(input_)
...@@ -123,6 +125,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function): ...@@ -123,6 +125,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
:param parallel_mode: parallel mode :param parallel_mode: parallel mode
:param dim: dimension :param dim: dimension
""" """
@staticmethod @staticmethod
def symbolic(graph, input_): def symbolic(graph, input_):
return _split(input_) return _split(input_)
...@@ -146,6 +149,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function): ...@@ -146,6 +149,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
:param parallel_mode: parallel mode :param parallel_mode: parallel mode
:param dim: dimension :param dim: dimension
""" """
@staticmethod @staticmethod
def symbolic(graph, input_): def symbolic(graph, input_):
return _gather(input_) return _gather(input_)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment