Commit d36b3c63 authored by Hubert Lu

Revert test_fused_layer_norm.py to avoid torch.cuda.is_bf16_supported, which is missing in PyTorch 1.9

parent 93f3a3bc
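Context for the revert: the pre-revert version of this test module evaluates torch.cuda.is_bf16_supported() at import time when building autocast_dtypes, and that helper only exists in newer PyTorch releases, so simply importing the file under PyTorch 1.9 fails with an AttributeError. A minimal, hypothetical guard is sketched below; it is not part of this commit (the commit just reverts the file), only an illustration of the failure mode being avoided.

import torch

# Hypothetical fallback (not in this commit): probe for bf16 support only when
# the helper exists, so the module still imports under PyTorch 1.9.
if getattr(torch.cuda, "is_bf16_supported", None) is not None and torch.cuda.is_bf16_supported():
    autocast_dtypes = (torch.half, torch.bfloat16)
else:
    autocast_dtypes = (torch.half,)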
-import itertools
 import unittest
+import os
+import random
 import torch
 import apex
+from torch.autograd import Variable
 
 class TestFusedLayerNorm(unittest.TestCase):
-    dtype = torch.float
-    elementwise_affine = False
-    normalized_shape = [32, 16]
-    rtol, atol = None, None
-    fwd_thresholds = dict(rtol=None, atol=None)
-    bwd_thresholds = dict(rtol=None, atol=None)
-
     def setUp(self):
         # bias and weight are set to 0 and 1 respectively, so no need to copy parameters from cpu module to the gpu one
-        self.module_cpu_ = apex.normalization.FusedLayerNorm(
-            normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).cpu()
-        self.module_cuda_ = apex.normalization.FusedLayerNorm(
-            normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).to(device="cuda", dtype=self.dtype)
+        self.module_cpu_ = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=False).cpu()
+        self.module_cuda_ = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=False).cuda()
 
-    def _check_same_output(self, batch_size, contiguous):
+    def _test_same_output(self, batch_size):
         torch.cuda.manual_seed(42)
-        if contiguous:
-            input_shape = [batch_size] + self.normalized_shape
-            input_ = torch.randn(input_shape, device="cpu").requires_grad_(True)
-            input_cuda_ = input_.to(device="cuda", dtype=self.dtype).detach().requires_grad_(True)
-            self.assertTrue(input_.is_contiguous())
-            self.assertTrue(input_cuda_.is_contiguous())
-        else:
-            input_shape = [batch_size] + self.normalized_shape
-            input_shape = [batch_size * 3] + [self.normalized_shape[0] * 5, self.normalized_shape[1] * 3]
-            input_src_ = torch.randn(input_shape, device="cpu")
-            input_ = input_src_[::3, ::5, ::3].detach().requires_grad_(True)
-            input_cuda_ = input_src_.to(device="cuda", dtype=self.dtype)[::3, ::5, ::3].detach().requires_grad_(True)
-            # make sure that tensors are NOT contiguous.
-            self.assertFalse(input_.is_contiguous())
-            self.assertFalse(input_cuda_.is_contiguous())
-        out_cpu_ = self.module_cpu_(input_)
+        self.input_ = torch.randn((batch_size, *self.module_cpu_.normalized_shape), device="cpu").requires_grad_(True)
+        self.input_cuda_ = self.input_.cuda().detach().requires_grad_(True)
+        out_cpu_ = self.module_cpu_(self.input_)
         gO = torch.rand_like(out_cpu_)
         out_cpu_.backward(gO)
-        out_cuda_ = self.module_cuda_(input_cuda_)
-        gO = gO.to(device="cuda", dtype=self.dtype)
+        out_cuda_ = self.module_cuda_(self.input_cuda_)
+        gO = gO.cuda()
         out_cuda_.backward(gO)
-        self.assertFalse(out_cpu_.is_cuda)
-        self.assertTrue(out_cuda_.is_cuda)
-        # TODO (mkozuki): `torch.testing.assert_allclose` is deprecated.
-        # Use `torch.testing.assert_close`.
-        # See https://github.com/pytorch/pytorch/issues/61844
-        torch.testing.assert_allclose(
-            out_cpu_.to(device="cuda", dtype=self.dtype), out_cuda_, **self.fwd_thresholds)
-        torch.testing.assert_allclose(
-            input_.grad.to(device="cuda", dtype=self.dtype), input_cuda_.grad, **self.bwd_thresholds)
-
-    def _test_same_output(self, batch_size):
-        for contiguous in (True, False):
-            with self.subTest(contiguous=contiguous):
-                self._check_same_output(batch_size, contiguous)
+        assert out_cpu_.is_cuda == False
+        assert out_cuda_.is_cuda == True
+        torch.testing.assert_allclose(out_cpu_, out_cuda_.cpu())
+        torch.testing.assert_allclose(self.input_.grad, self.input_cuda_.grad.cpu())
 
     def test_layer_norm(self):
         self._test_same_output(16)
@@ -67,105 +36,10 @@ class TestFusedLayerNorm(unittest.TestCase):
 class TestFusedLayerNormElemWise(TestFusedLayerNorm):
-    elementwise_affine = True
-
-class TestFusedLayerNormElemWiseHalf(TestFusedLayerNormElemWise):
-    dtype = torch.half
-
-    def test_large_batch(self):
-        self.skipTest("Skip to save time")
-
-# Megatron style Layer Norm
-class TestFusedLayerNormElemWiseMixedDtypes(TestFusedLayerNorm):
-    def setUp(self):
-        self.module_cpu_ = apex.normalization.MixedFusedLayerNorm(
-            normalized_shape=self.normalized_shape, elementwise_affine=True).cpu()
-        self.module_cuda_ = apex.normalization.MixedFusedLayerNorm(
-            normalized_shape=self.normalized_shape, elementwise_affine=True).to(device="cuda", dtype=self.dtype)
-
-    def test_init_exception(self):
-        with self.assertRaisesRegex(RuntimeError, "MixedFusedLayerNorm does not support `elementwise_affine = False`"):
-            apex.normalization.MixedFusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=False).cuda()
-
-class TestFusedLayerNormElemWiseMixedDtypesHalf(TestFusedLayerNormElemWiseMixedDtypes):
-    dtype = torch.half
-
-    def test_large_batch(self):
-        self.skipTest("Skip to save time")
-
-# NOTE (mkozuki): With the larger threshold values, still flaky.
-class TestFusedLayerNormElemWiseMixedDtypesBFloat16(TestFusedLayerNormElemWiseMixedDtypesHalf):
-    dtype = torch.bfloat16
-    # NOTE (mkozuki): [BFloat16 Layer Norm flakiness]
-    # Use thresholds larger than those used in pytorch, see
-    # https://github.com/pytorch/pytorch/blob/72274e2a2fd55019ec860e1743dbdc5b0c5a5624/torch/testing/_asserts.py#L26
-    fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)
-    bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)
-
-class TestFusedLayerNormElemWiseBFloat16(TestFusedLayerNormElemWise):
-    dtype = torch.bfloat16
-    # See [BFloat16 Layer Norm flakiness]
-    fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)
-    bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)
-
-    def test_large_batch(self):
-        self.skipTest("Skip to save time")
-
-def _prep_layers(normalized_shape, elementwise_affine, dtype):
-    native = torch.nn.LayerNorm(
-        normalized_shape=normalized_shape, elementwise_affine=elementwise_affine
-    ).to(device="cuda", dtype=dtype)
-    fused = apex.normalization.FusedLayerNorm(
-        normalized_shape=normalized_shape, elementwise_affine=elementwise_affine
-    ).cuda()
-    return native, fused
-
-def _prep_inputs(batch_size, normalized_shape, dtype):
-    shape = (batch_size, *normalized_shape)
-    fused = torch.randn(shape).cuda().requires_grad_(True)
-    with torch.no_grad():
-        native = fused.clone().to(dtype).requires_grad_(True)
-    return native, fused
-
-autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
-
-class TestAutocastFusedLayerNorm(unittest.TestCase):
-    bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)
-    bf16_bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)
-
-    def setUp(self):
-        self.batch_size = 16
-        self.normalized_shape = [32, 16]
-
-    def _run_test(self, dtype, elementwise_affine):
-        native, fused = _prep_layers(self.normalized_shape, elementwise_affine, dtype)
-        native_x, fused_x = _prep_inputs(self.batch_size, self.normalized_shape, dtype)
-        expected = native(native_x)
-        with torch.cuda.amp.autocast(dtype=dtype):
-            actual = fused(fused_x)
-        tols = {'rtol': None, 'atol': None} if dtype == torch.half else TestAutocastFusedLayerNorm.bf16_fwd_thresholds
-        torch.testing.assert_allclose(actual, expected, **tols)
-        g_native = torch.rand_like(expected)
-        with torch.no_grad():
-            g_fused = g_native.clone()
-        expected.backward(g_native)
-        actual.backward(g_fused)
-        tols = {'rtol': None, 'atol': None} if dtype == torch.half else TestAutocastFusedLayerNorm.bf16_bwd_thresholds
-        torch.testing.assert_allclose(native_x.grad, fused_x.grad, **tols)
-
-    def test_autocast(self):
-        for (dtype, elementwise_affine) in itertools.product(autocast_dtypes, (True, False)):
-            with self.subTest(f"{dtype}-{elementwise_affine}"):
-                self._run_test(dtype, elementwise_affine)
+    def setUp(self):
+        self.module_cpu_ = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=True).cpu()
+        self.module_cuda_ = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=True).cuda()
+
+if __name__ == '__main__':
+    unittest.main()