Commit d1626ccc authored by Myle Ott

Update FusedLayerNorm for new function API

parent 574fe244
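Note: the "new function API" here is PyTorch's static-method style for torch.autograd.Function: forward and backward become @staticmethods that receive a ctx object, every argument (including non-tensor configuration such as normalized_shape and eps) is passed through Function.apply(...), and backward must return one value per forward argument, with None for arguments that need no gradient. A minimal, self-contained sketch of that pattern (illustrative only, not apex code; the ScaledAdd name, shapes, and scale value are made up):

import torch

class ScaledAdd(torch.autograd.Function):
    """y = input * scale + bias; illustrates the static forward/backward API."""

    @staticmethod
    def forward(ctx, input, bias, scale):
        # Non-tensor arguments are stashed on ctx; tensors go through save_for_backward.
        ctx.scale = scale
        ctx.save_for_backward(input, bias)
        return input * scale + bias

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        grad_input = grad_output * ctx.scale
        grad_bias = grad_output
        # One return value per forward() argument; None for the non-tensor `scale`.
        return grad_input, grad_bias, None

x = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
b = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
# Functions are invoked through .apply(); gradcheck verifies the hand-written backward.
torch.autograd.gradcheck(lambda x, b: ScaledAdd.apply(x, b, 0.5), (x, b))

The diff below applies exactly this conversion to the two fused layer-norm functions.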
@@ -6,60 +6,66 @@ from torch.nn import init
 from torch.nn import functional as F
 import importlib
 
-class FusedLayerNormAffineFunction(torch.autograd.Function):
+global fused_layer_norm_cuda
+fused_layer_norm_cuda = None
 
-    def __init__(self, normalized_shape, eps=1e-6):
-        global fused_layer_norm_cuda
-        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
-        self.normalized_shape = normalized_shape
-        self.eps = eps
+class FusedLayerNormAffineFunction(torch.autograd.Function):
 
-    def forward(self, input, weight, bias):
+    @staticmethod
+    def forward(ctx, input, weight, bias, normalized_shape, eps):
+        global fused_layer_norm_cuda
+        if fused_layer_norm_cuda is None:
+            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
+        ctx.normalized_shape = normalized_shape
+        ctx.eps = eps
         input_ = input.contiguous()
         weight_ = weight.contiguous()
         bias_ = bias.contiguous()
         output, mean, invvar = fused_layer_norm_cuda.forward_affine(
-            input_, self.normalized_shape, weight_, bias_, self.eps)
-        self.save_for_backward(input_, weight_, bias_, mean, invvar)
+            input_, ctx.normalized_shape, weight_, bias_, eps)
+        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
         return output
 
-    def backward(self, grad_output):
-        input_, weight_, bias_, mean, invvar = self.saved_tensors
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
         grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
             grad_output.contiguous(), mean, invvar,
-            input_, self.normalized_shape,
-            weight_, bias_, self.eps)
-        return grad_input, grad_weight, grad_bias;
+            input_, ctx.normalized_shape,
+            weight_, bias_, ctx.eps)
+        return grad_input, grad_weight, grad_bias, None, None
 
 class FusedLayerNormFunction(torch.autograd.Function):
 
-    def __init__(self, normalized_shape, eps=1e-6):
-        global fused_layer_norm_cuda
-        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
-        self.normalized_shape = normalized_shape
-        self.eps = eps
-
-    def forward(self, input):
+    @staticmethod
+    def forward(ctx, input, normalized_shape, eps):
+        global fused_layer_norm_cuda
+        if fused_layer_norm_cuda is None:
+            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
+        ctx.normalized_shape = normalized_shape
+        ctx.eps = eps
         input_ = input.contiguous()
         output, mean, invvar = fused_layer_norm_cuda.forward(
-            input_, self.normalized_shape, self.eps)
-        self.save_for_backward(input_, mean, invvar)
+            input_, ctx.normalized_shape, ctx.eps)
+        ctx.save_for_backward(input_, mean, invvar)
         return output
 
-    def backward(self, grad_output):
-        input_, mean, invvar = self.saved_tensors
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, mean, invvar = ctx.saved_tensors
         grad_input = None
         grad_input = fused_layer_norm_cuda.backward(
             grad_output.contiguous(), mean, invvar,
-            input_, self.normalized_shape,
-            self.eps)
-        return grad_input
+            input_, ctx.normalized_shape,
+            ctx.eps)
+        return grad_input, None, None
 
 def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):
-    return FusedLayerNormAffineFunction(normalized_shape, eps)(input, weight, bias)
+    return FusedLayerNormAffineFunction.apply(input, weight, bias, normalized_shape, eps)
 
 def fused_layer_norm(input, normalized_shape, eps=1e-6):
-    return FusedLayerNormFunction(normalized_shape, eps)(input)
+    return FusedLayerNormFunction.apply(input, normalized_shape, eps)
 
 class FusedLayerNorm(torch.nn.Module):
     r"""Applies Layer Normalization over a mini-batch of inputs as described in
@@ -149,11 +155,10 @@ class FusedLayerNorm(torch.nn.Module):
             return F.layer_norm(
                 input, self.normalized_shape, self.weight, self.bias, self.eps)
         if self.elementwise_affine:
-            return FusedLayerNormAffineFunction(self.normalized_shape, self.eps)(
-                input, self.weight, self.bias)
+            return FusedLayerNormAffineFunction.apply(
+                input, self.weight, self.bias, self.normalized_shape, self.eps)
         else:
-            return FusedLayerNormFunction(self.normalized_shape, self.eps)(
-                input)
+            return FusedLayerNormFunction.apply(input, self.normalized_shape, self.eps)
 
     def extra_repr(self):
         return '{normalized_shape}, eps={eps}, ' \
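After this change, callers go through .apply() instead of instantiating the Function with its configuration. A hedged usage sketch (assumes the fused_layer_norm_cuda extension is built, a CUDA device is available, and that the wrappers above are importable; the import path below is an assumption and may differ per install):

import torch
# Assumed import path; adjust to wherever this file lives in your install.
from apex.normalization.fused_layer_norm import fused_layer_norm_affine, FusedLayerNorm

x = torch.randn(8, 512, device="cuda", requires_grad=True)
w = torch.ones(512, device="cuda", requires_grad=True)
b = torch.zeros(512, device="cuda", requires_grad=True)

# Functional form: dispatches to FusedLayerNormAffineFunction.apply(...)
y = fused_layer_norm_affine(x, (512,), w, b, eps=1e-6)
y.sum().backward()  # the trailing None, None returns cover normalized_shape and eps

# Module form: FusedLayerNorm.forward() now also routes through .apply()
ln = FusedLayerNorm(512).cuda()
z = ln(x)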