Commit d130ec1f authored by Lam Dang's avatar Lam Dang
Browse files

Quick fix: make FusedLayerNorm compatible with CPU

parent 683b6e0e
@@ -3,6 +3,7 @@ import torch
 import numbers
 from torch.nn.parameter import Parameter
 from torch.nn import init
+from torch.nn import functional as F
 import importlib

 class FusedLayerNormAffineFunction(torch.autograd.Function):
@@ -144,6 +145,9 @@ class FusedLayerNorm(torch.nn.Module):
             init.zeros_(self.bias)

     def forward(self, input):
+        if not input.is_cuda:
+            return F.layer_norm(
+                input, self.normalized_shape, self.weight, self.bias, self.eps)
         if self.elementwise_affine:
             return FusedLayerNormAffineFunction(self.normalized_shape,self.eps)(
                 input, self.weight, self.bias)
...
import unittest
import os
import random
import torch
import apex
class TestFusedLayerNorm(unittest.TestCase):
    """Exercise FusedLayerNorm without affine parameters on both CPU and
    CUDA, and verify the two execution paths agree numerically.

    The CPU path is the ``F.layer_norm`` fallback added in this commit;
    the CUDA path uses the fused kernel.
    """

    def setUp(self):
        # elementwise_affine=False: layer norm with no weight/bias parameters.
        self.module = apex.normalization.FusedLayerNorm(
            normalized_shape=[32, 64], elementwise_affine=False)
        self.input_ = torch.randn(16, 32, 64)
        torch.cuda.manual_seed(42)

    def forward_cpu(self, input_):
        # Move module and input to CPU so the non-fused fallback is exercised.
        self.module.cpu()
        return self.module(input_.cpu())

    def forward_cuda(self, input_):
        # Move module and input to the GPU to hit the fused kernel.
        self.module.cuda()
        return self.module(input_.cuda())

    def test_forward_cuda(self):
        out_ = self.forward_cuda(self.input_)
        # self.assertTrue survives python -O (bare asserts do not) and
        # reports a useful message on failure; `== True` was redundant.
        self.assertTrue(out_.is_cuda)

    def test_forward_cpu(self):
        out_ = self.forward_cpu(self.input_)
        self.assertFalse(out_.is_cuda)

    def test_same_output(self):
        out_cpu = self.forward_cpu(self.input_)
        out_cuda = self.forward_cuda(self.input_)
        # The CPU fallback and the fused CUDA kernel must match to within
        # assert_allclose's default tolerance.
        torch.testing.assert_allclose(out_cpu, out_cuda.cpu())
class TestFusedLayerNormElemWise(TestFusedLayerNorm):
    """Repeat every TestFusedLayerNorm check with learnable weight and bias
    (elementwise_affine=True)."""

    def setUp(self):
        shape = [32, 64]
        self.module = apex.normalization.FusedLayerNorm(
            normalized_shape=shape, elementwise_affine=True)
        self.input_ = torch.randn(16, *shape)
        torch.cuda.manual_seed(42)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment