Commit 3ef01fae authored by ngimel, committed by mcarilli

Clean up layer norm tests (#418)

* Bug fix for non-affine layer-norm + add backward unit test

* Clean up tests and add tests for a large batch
parent 37795aac
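
For context, the failure mode addressed here: with elementwise_affine=False, FusedLayerNorm has no weight or bias tensors, so the backward pass hands NULL gamma/beta pointers to the CUDA kernel, which the old code dereferenced unconditionally. A minimal sketch (not part of the commit) that exercises this path:

    import torch
    import apex

    # Non-affine layer norm: no learnable weight/bias, so the fused backward
    # kernel is invoked with NULL gamma/beta (the path this commit guards).
    ln = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16],
                                           elementwise_affine=False).cuda()
    x = torch.randn(4, 32, 16, device="cuda", requires_grad=True)
    out = ln(x)
    out.backward(torch.rand_like(out))
    print(x.grad.shape)  # torch.Size([4, 32, 16])
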
@@ -795,11 +795,13 @@ void cuda_layer_norm_gradient(
         invvar->data<accscalar_t>(),
         input,
         n1,n2,
-        gamma->data<scalar_t_0>(),
-        beta->data<scalar_t_0>(),
+        // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
+        // if gamma Tensor is NULL on input.
+        gamma != NULL ? gamma->data<scalar_t_0>() : NULL,
+        gamma != NULL ? beta->data<scalar_t_0>() : NULL,
         epsilon,
         grad_input->data<scalar_t_0>(),
-        grad_gamma->data<scalar_t_0>(),
-        grad_beta->data<scalar_t_0>());
+        gamma != NULL ? grad_gamma->data<scalar_t_0>() : NULL,
+        gamma != NULL ? grad_beta->data<scalar_t_0>() : NULL);
     )
 }
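
The ternaries above propagate NULL when the module is non-affine; on the Python side that corresponds to the module simply having no registered parameters. A quick illustration (a sketch, not from the commit):

    import apex

    # A non-affine FusedLayerNorm registers no weight/bias parameters,
    # which is why gamma/beta arrive as NULL in the C++ gradient entry point.
    affine = apex.normalization.FusedLayerNorm([32, 16], elementwise_affine=True)
    plain = apex.normalization.FusedLayerNorm([32, 16], elementwise_affine=False)
    print([name for name, _ in affine.named_parameters()])  # ['weight', 'bias']
    print([name for name, _ in plain.named_parameters()])   # []
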
@@ -4,38 +4,39 @@ import random
 import torch
 import apex
-from torch.autograd import Variable
 
 class TestFusedLayerNorm(unittest.TestCase):
     def setUp(self):
-        self.module = apex.normalization.FusedLayerNorm(normalized_shape=[32, 64], elementwise_affine=False)
-        self.input_ = torch.randn(16, 32, 64)
-        torch.cuda.manual_seed(42)
-
-    def forward_cpu(self, input_):
-        self.module.cpu()
-        return self.module(input_.cpu())
-
-    def forward_cuda(self, input_):
-        self.module.cuda()
-        return self.module(input_.cuda())
-
-    def test_forward_cuda(self):
-        out_ = self.forward_cuda(self.input_)
-        assert out_.is_cuda == True
-
-    def test_forward_cpu(self):
-        out_ = self.forward_cpu(self.input_)
-        assert out_.is_cuda == False
-
-    def test_same_output(self):
-        out_cpu = self.forward_cpu(self.input_)
-        out_cuda = self.forward_cuda(self.input_)
-        torch.testing.assert_allclose(out_cpu, out_cuda.cpu())
+        # bias and weight are set to 0 and 1 respectively, so no need to copy parameters from cpu module to the gpu one
+        self.module_cpu_ = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=False).cpu()
+        self.module_cuda_ = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=False).cuda()
+
+    def _test_same_output(self, batch_size):
+        torch.cuda.manual_seed(42)
+        self.input_ = torch.randn((batch_size, *self.module_cpu_.normalized_shape), device="cpu").requires_grad_(True)
+        self.input_cuda_ = self.input_.cuda().detach().requires_grad_(True)
+        out_cpu_ = self.module_cpu_(self.input_)
+        gO = torch.rand_like(out_cpu_)
+        out_cpu_.backward(gO)
+        out_cuda_ = self.module_cuda_(self.input_cuda_)
+        gO = gO.cuda()
+        out_cuda_.backward(gO)
+        assert out_cpu_.is_cuda == False
+        assert out_cuda_.is_cuda == True
+        torch.testing.assert_allclose(out_cpu_, out_cuda_.cpu())
+        torch.testing.assert_allclose(self.input_.grad, self.input_cuda_.grad.cpu())
+
+    def test_layer_norm(self):
+        self._test_same_output(16)
+
+    def test_large_batch(self):
+        self._test_same_output(65536)
 
 class TestFusedLayerNormElemWise(TestFusedLayerNorm):
     def setUp(self):
-        self.module = apex.normalization.FusedLayerNorm(normalized_shape=[32, 64], elementwise_affine=True)
-        self.input_ = torch.randn(16, 32, 64)
-        torch.cuda.manual_seed(42)
+        self.module_cpu_ = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=True).cpu()
+        self.module_cuda_ = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=True).cuda()
\ No newline at end of file
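
The setUp comment above relies on FusedLayerNorm's deterministic initialization (weight to ones, bias to zeros), so the independently constructed CPU and CUDA modules already agree without any state copy. A sketch (not from the commit) checking that assumption:

    import torch
    import apex

    # Freshly constructed affine modules match without any parameter copy:
    # weight is initialized to all ones and bias to all zeros.
    a = apex.normalization.FusedLayerNorm([32, 16], elementwise_affine=True)
    b = apex.normalization.FusedLayerNorm([32, 16], elementwise_affine=True)
    assert torch.equal(a.weight, b.weight)  # all ones
    assert torch.equal(a.bias, b.bias)      # all zeros
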