Unverified Commit e57d9e79 authored by Aidyn-A, committed by GitHub

[transformer] update tests (#1428)

parent 208d9670
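In short, the commit moves the transformer tests from stdlib unittest plus torch.testing.assert_close onto PyTorch's internal test harness, torch.testing._internal.common_utils, whose TestCase.assertEqual accepts tensors and compares them element-wise with dtype-aware default tolerances. A minimal sketch of the pattern applied throughout (the test class and tensors below are illustrative, not taken from the diff):

# Minimal sketch of the migration pattern, assuming PyTorch's internal test
# harness (torch.testing._internal.common_utils). Its TestCase.assertEqual
# handles tensors directly, much like torch.testing.assert_close.
import torch
from torch.testing._internal import common_utils


class ExampleMigratedTest(common_utils.TestCase):  # hypothetical test class
    def test_tensor_comparison(self):
        a = torch.randn(4, 4)
        b = a.clone()
        # Before: torch.testing.assert_close(a, b)
        # After: the harness assertEqual compares tensors with default tolerances.
        self.assertEqual(a, b)


if __name__ == "__main__":
    # Replaces unittest.main(); also wires up PyTorch's test-runner flags.
    common_utils.run_tests()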
 from itertools import product
-import unittest
 import torch
+from torch.testing._internal import common_utils
 from torch.utils.data import Dataset
 from torch.utils.data import RandomSampler
 from torch.utils.data import BatchSampler
@@ -80,7 +80,7 @@ class MegatronPretrainingRandomSampler:
 # Samples 8 tensors in total.
 # First sample 4 tensors twice, then sample 2 tensors fourth.
-class TestBatchSamplerBehavior(unittest.TestCase):
+class TestBatchSamplerBehavior(common_utils.TestCase):
     def test_batch_sampler_behavior(self):
         dataset = MyIterableDataset(0, 100)
@@ -101,7 +101,7 @@ class TestBatchSamplerBehavior(unittest.TestCase):
             samples2.append(batch)
             if i == 4 - 1:
                 break
-        torch.testing.assert_close(torch.cat(samples), torch.cat(samples2))
+        self.assertEqual(torch.cat(samples), torch.cat(samples2))
     def test_split_batch(self):
@@ -139,4 +139,4 @@ class TestBatchSamplerBehavior(unittest.TestCase):
 if __name__ == "__main__":
-    unittest.main()
+    common_utils.run_tests()
@@ -80,8 +80,8 @@ class VocabParallelCrossEntropyTestBase:
             batch_size, sequence_length, vocab_size, logits_scale, seed
         )
-        torch.testing.assert_close(loss_torch, loss_tensor_parallel)
-        torch.testing.assert_close(grad_torch, grad_tensor_parallel)
+        self.assertEqual(loss_torch, loss_tensor_parallel)
+        self.assertEqual(grad_torch, grad_tensor_parallel)
         parallel_state.destroy_model_parallel()
...
@@ -51,7 +51,7 @@ class BroadcastDataTestBase:
         broadcasted_data = data_utils.broadcast_data(keys, data, torch.int64)
         for key in keys:
-            torch.testing.assert_close(broadcasted_data[key], data_t[key].cuda())
+            self.assertEqual(broadcasted_data[key], data_t[key].cuda())
         parallel_state.destroy_model_parallel()
...
@@ -3,9 +3,9 @@
 Ref: https://github.com/NVIDIA/Megatron-LM/blob/40becfc96c4144985458ac0e0fae45dbb111fbd2/megatron/fused_kernels/tests/test_fused_kernels.py
 """  # NOQA
 import itertools
-import unittest
 import torch
+from torch.testing._internal import common_utils
 from apex.transformer import AttnMaskType
 from apex.transformer.functional import FusedScaleMaskSoftmax
@@ -20,7 +20,7 @@ autocast_dtypes = (
 )
-class TestFusedScaleMaskSoftmax(unittest.TestCase):
+class TestFusedScaleMaskSoftmax(common_utils.TestCase):
     def _setup_fused_softmax(
         self,
         input_in_fp16,
@@ -89,7 +89,7 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
         mask = torch.randint(0, 2, mask_shape, device="cuda").bool()
         expected = fused_fn(attention_scores_0, mask)
         actual = torch_fn(attention_scores_1, mask)
-        torch.testing.assert_close(actual, expected)
+        self.assertEqual(actual, expected)
         g0 = torch.rand_like(actual)
         with torch.no_grad():
@@ -119,7 +119,7 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
             with torch.cuda.amp.autocast(dtype=dtype):
                 actual = fused_fn(attention_scores_0, mask)
                 self.assertEqual(actual.dtype, dtype)
-            torch.testing.assert_close(actual, expected)
+            self.assertEqual(actual, expected)
             g0 = torch.rand_like(actual)
             with torch.no_grad():
@@ -174,7 +174,7 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
         total_mask = total_mask.repeat((4, 1, 1, 1))
         expected = fused_fn(attn_weights_0, total_mask)
         actual = torch_fn(attn_weights_1, total_mask)
-        torch.testing.assert_close(actual, expected)
+        self.assertEqual(actual, expected)
         g0 = torch.randn_like(actual)
         with torch.no_grad():
@@ -208,10 +208,13 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
                 actual = fused_fn(attn_weights_0, total_mask)
                 self.assertEqual(actual.dtype, dtype)
             expected = torch_fn(attn_weights_1, total_mask)
-            torch.testing.assert_close(actual, expected)
+            self.assertEqual(actual, expected)
             g0 = torch.randn_like(actual)
             with torch.no_grad():
                 g1 = g0.clone()
             actual.backward(g0)
             expected.backward(g1)
+
+if __name__ == "__main__":
+    common_utils.run_tests()
@@ -82,7 +82,7 @@ class TensorParallelLayerTestBase:
             group=parallel_state.get_tensor_model_parallel_group(),
         )
-        torch.testing.assert_close(gathered, gathered_for_base)
+        self.assertEqual(gathered, gathered_for_base)
         parallel_state.destroy_model_parallel()
     @torch.no_grad()
@@ -130,8 +130,8 @@ class TensorParallelLayerTestBase:
             group=parallel_state.get_tensor_model_parallel_group(),
         )
-        torch.testing.assert_close(output, output_for_base)
-        torch.testing.assert_close(input, torch.cat(input_list))
+        self.assertEqual(output, output_for_base)
+        self.assertEqual(input, torch.cat(input_list))
         parallel_state.destroy_model_parallel()
     def test_parallel_embedding(self) -> None:
@@ -376,25 +376,25 @@ class TensorParallelLayerTestBase:
         if not accumulation_in_fp16:
             if sequence_parallel_enabled:
-                torch.testing.assert_close(
-                    actual=output,
-                    expected=expected_output.chunk(
+                self.assertEqual(
+                    x=output,
+                    y=expected_output.chunk(
                         chunks=tensor_model_parallel_world_size,
                         dim=0,
                     )[parallel_state.get_tensor_model_parallel_rank()],
                 )
             else:
-                torch.testing.assert_close(
-                    actual=output,
-                    expected=expected_output,
+                self.assertEqual(
+                    x=output,
+                    y=expected_output,
                 )
         grad_attr_name = "main_grad" if gradient_accumulation_fusion else "grad"
         # NOTE(mkozuki): Numerical errors seems to be enlarged by tensor model parallel.
         if tensor_model_parallel_world_size == 1:
-            torch.testing.assert_close(
-                actual=getattr(linear.weight, grad_attr_name),
-                expected=ref_linear.weight.grad.chunk(
+            self.assertEqual(
+                x=getattr(linear.weight, grad_attr_name),
+                y=ref_linear.weight.grad.chunk(
                     chunks=tensor_model_parallel_world_size,
                     dim=0,
                 )[parallel_state.get_tensor_model_parallel_rank()],
@@ -520,14 +520,14 @@ class TensorParallelLayerTestBase:
                     tensor_model_parallel_world_size,
                     dim=2,
                 )[parallel_state.get_tensor_model_parallel_rank()]
-                torch.testing.assert_close(
-                    actual=output,
-                    expected=chunk,
+                self.assertEqual(
+                    x=output,
+                    y=chunk,
                 )
             else:
-                torch.testing.assert_close(
-                    actual=output,
-                    expected=expected_output,
+                self.assertEqual(
+                    x=output,
+                    y=expected_output,
                 )
         expected_loss = torch.mul(expected_output, dldy).sum()
@@ -535,9 +535,9 @@ class TensorParallelLayerTestBase:
         grad_attr_name = "main_grad" if gradient_accumulation_fusion else "grad"
         # NOTE(mkozuki): Numerical errors seems to be enlarged by tensor model parallel.
         if tensor_model_parallel_world_size == 1:
-            torch.testing.assert_close(
-                actual=getattr(linear.weight, grad_attr_name),
-                expected=ref_linear.weight.grad.chunk(
+            self.assertEqual(
+                x=getattr(linear.weight, grad_attr_name),
+                y=ref_linear.weight.grad.chunk(
                     chunks=tensor_model_parallel_world_size,
                     dim=0,
                 )[parallel_state.get_tensor_model_parallel_rank()],
...
@@ -205,7 +205,7 @@ class PipelineParallelForwardBackwardTestBase:
         for loss_item in loss:
             x = loss_item['avg']
-            torch.testing.assert_close(x.item() / microbatch_size, target_loss.item())
+            self.assertEqual(x.item() / microbatch_size, target_loss.item())
         if not forward_only:
             for vm_id, model_module in enumerate(model):
@@ -215,10 +215,10 @@ class PipelineParallelForwardBackwardTestBase:
                 param_id = rank // data_parallel_size + vm_id * offset
                 target_params = target_model[param_id]
-                torch.testing.assert_close(params[0].cpu(), target_params[0])
-                torch.testing.assert_close(params[1].cpu(), target_params[1])
-                torch.testing.assert_close(params[0].grad.cpu() / microbatch_size, target_params[0].grad)
-                torch.testing.assert_close(params[1].grad.cpu() / microbatch_size, target_params[1].grad)
+                self.assertEqual(params[0].cpu(), target_params[0])
+                self.assertEqual(params[1].cpu(), target_params[1])
+                self.assertEqual(params[0].grad.cpu() / microbatch_size, target_params[0].grad)
+                self.assertEqual(params[1].grad.cpu() / microbatch_size, target_params[1].grad)
         if not forward_only:
             for m in model:
...
@@ -52,7 +52,7 @@ class TransformerRandomTestBase:
         torch.randn(size, out=tensor)
         result_2 = tensor.clone()
-        torch.testing.assert_close(result_2, result_1)
+        self.assertEqual(result_2, result_1)
         self.assertEqual(rng_state.sub(rng_state_clone).max(), 0)
...
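The NOTE(mkozuki) comments in the tensor-parallel layer tests explain why the gradient comparisons only run at tensor model parallel world size 1: numerical error grows with parallelism. Worth noting is that the harness assertEqual also accepts explicit atol/rtol overrides, so tolerances could in principle be loosened instead of the check being skipped. A hedged sketch of that alternative (illustrative values, not what this commit does):

# Hedged sketch: common_utils.TestCase.assertEqual takes explicit atol/rtol
# keyword overrides, so a comparison that fails the dtype-default tolerances
# can be relaxed rather than skipped. The values below are illustrative only.
import torch
from torch.testing._internal import common_utils


class ToleranceExample(common_utils.TestCase):  # hypothetical test class
    def test_loose_tolerance(self):
        a = torch.ones(8)
        b = a + 1e-4  # small drift, e.g. from a different parallel reduction order
        # Default float32 tolerances (roughly rtol 1.3e-6, atol 1e-5) would
        # reject this; explicit overrides accept it.
        self.assertEqual(a, b, atol=1e-3, rtol=1e-3)


if __name__ == "__main__":
    common_utils.run_tests()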