Unverified commit 3d01e4a0, authored by ngimel, committed by GitHub

Merge pull request #400 from myleott/new_function_api

Update FusedLayerNorm for new function API
parents 574fe244 0dbf6c2a
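
For reference, the "new function API" targeted by this PR is PyTorch's static-method style for torch.autograd.Function: forward() and backward() become @staticmethods that receive a ctx object, non-tensor arguments are stashed on ctx instead of being passed to __init__, the function is invoked through .apply(...), and backward() must return one gradient (or None) per forward() argument. A minimal sketch of the pattern, using a made-up ScaleFunction purely for illustration (not part of apex):

import torch

class ScaleFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, scale):
        # Non-tensor arguments live on ctx rather than on an instance.
        ctx.scale = scale
        return input * scale

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward() argument; None for the non-tensor scale.
        return grad_output * ctx.scale, None

x = torch.randn(4, requires_grad=True)
y = ScaleFunction.apply(x, 2.0)   # invoked via .apply, not by instantiating the Function
y.sum().backward()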
@@ -6,60 +6,66 @@ from torch.nn import init
 from torch.nn import functional as F
 import importlib

-class FusedLayerNormAffineFunction(torch.autograd.Function):
-    def __init__(self, normalized_shape, eps=1e-6):
-        global fused_layer_norm_cuda
-        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
-        self.normalized_shape = normalized_shape
-        self.eps = eps
+global fused_layer_norm_cuda
+fused_layer_norm_cuda = None

-    def forward(self, input, weight, bias):
+class FusedLayerNormAffineFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, weight, bias, normalized_shape, eps):
+        global fused_layer_norm_cuda
+        if fused_layer_norm_cuda is None:
+            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
+        ctx.normalized_shape = normalized_shape
+        ctx.eps = eps
         input_ = input.contiguous()
         weight_ = weight.contiguous()
         bias_ = bias.contiguous()
         output, mean, invvar = fused_layer_norm_cuda.forward_affine(
-            input_, self.normalized_shape, weight_, bias_, self.eps)
-        self.save_for_backward(input_, weight_, bias_, mean, invvar)
+            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
+        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
         return output

-    def backward(self, grad_output):
-        input_, weight_, bias_, mean, invvar = self.saved_tensors
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
         grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
             grad_output.contiguous(), mean, invvar,
-            input_, self.normalized_shape,
-            weight_, bias_, self.eps)
-        return grad_input, grad_weight, grad_bias;
+            input_, ctx.normalized_shape,
+            weight_, bias_, ctx.eps)
+        return grad_input, grad_weight, grad_bias, None, None

 class FusedLayerNormFunction(torch.autograd.Function):
-    def __init__(self, normalized_shape, eps=1e-6):
-        global fused_layer_norm_cuda
-        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
-        self.normalized_shape = normalized_shape
-        self.eps = eps

-    def forward(self, input):
+    @staticmethod
+    def forward(ctx, input, normalized_shape, eps):
+        global fused_layer_norm_cuda
+        if fused_layer_norm_cuda is None:
+            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
+        ctx.normalized_shape = normalized_shape
+        ctx.eps = eps
         input_ = input.contiguous()
         output, mean, invvar = fused_layer_norm_cuda.forward(
-            input_, self.normalized_shape, self.eps)
-        self.save_for_backward(input_, mean, invvar)
+            input_, ctx.normalized_shape, ctx.eps)
+        ctx.save_for_backward(input_, mean, invvar)
         return output

-    def backward(self, grad_output):
-        input_, mean, invvar = self.saved_tensors
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, mean, invvar = ctx.saved_tensors
         grad_input = None
         grad_input = fused_layer_norm_cuda.backward(
             grad_output.contiguous(), mean, invvar,
-            input_, self.normalized_shape,
-            self.eps)
-        return grad_input
+            input_, ctx.normalized_shape,
+            ctx.eps)
+        return grad_input, None, None

 def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):
-    return FusedLayerNormAffineFunction(normalized_shape, eps)(input, weight, bias)
+    return FusedLayerNormAffineFunction.apply(input, weight, bias, normalized_shape, eps)

 def fused_layer_norm(input, normalized_shape, eps=1e-6):
-    return FusedLayerNormFunction(normalized_shape, eps)(input)
+    return FusedLayerNormFunction.apply(input, normalized_shape, eps)

 class FusedLayerNorm(torch.nn.Module):
     r"""Applies Layer Normalization over a mini-batch of inputs as described in
@@ -149,11 +155,10 @@ class FusedLayerNorm(torch.nn.Module):
             return F.layer_norm(
                 input, self.normalized_shape, self.weight, self.bias, self.eps)
         if self.elementwise_affine:
-            return FusedLayerNormAffineFunction(self.normalized_shape, self.eps)(
-                input, self.weight, self.bias)
+            return FusedLayerNormAffineFunction.apply(
+                input, self.weight, self.bias, self.normalized_shape, self.eps)
         else:
-            return FusedLayerNormFunction(self.normalized_shape, self.eps)(
-                input)
+            return FusedLayerNormFunction.apply(input, self.normalized_shape, self.eps)

     def extra_repr(self):
         return '{normalized_shape}, eps={eps}, ' \
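
After this change, callers go through the staticmethod .apply path (directly or via the module wrapper). A short usage sketch, assuming apex is installed with its CUDA extension built and a GPU is available, since fused_layer_norm_cuda only handles CUDA tensors:

import torch
from apex.normalization import FusedLayerNorm

x = torch.randn(16, 512, device="cuda", requires_grad=True)
ln = FusedLayerNorm(512).cuda()    # elementwise_affine=True by default
y = ln(x)                          # dispatches to FusedLayerNormAffineFunction.apply(...)
y.sum().backward()                 # gradients flow through the static backward()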