Unverified Commit 96850dfa authored by Jithun Nair, committed by GitHub

Merge pull request #80 from ROCmSoftwarePlatform/IFU-master-2022-07-29

IFU-master-2022-07-29
parents 87fc4125 cc5f83b5
import torch
from apex.transformer.tensor_parallel import utils
def test_divide():
assert utils.divide(8, 4) == 2
def test_split_tensor_along_last_dim():
inputy = torch.randn((100, 100, 100))
splits = utils.split_tensor_along_last_dim(inputy, 10)
last_dim_shapes = torch.tensor([int(split.size()[-1]) for split in splits])
assert torch.equal(last_dim_shapes, torch.full((10,), 10))
if __name__ == "__main__":
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
test_divide()
test_split_tensor_along_last_dim()
print(">> passed the test :-)")
from itertools import product
import unittest
import torch
from torch.testing._internal import common_utils
from torch.utils.data import Dataset
from torch.utils.data import RandomSampler
from torch.utils.data import BatchSampler
@@ -80,7 +80,7 @@ class MegatronPretrainingRandomSampler:
# Samples 8 tensors in total.
# First sample 4 tensors twice, then sample 2 tensors four times.
class TestBatchSamplerBehavior(unittest.TestCase):
class TestBatchSamplerBehavior(common_utils.TestCase):
def test_batch_sampler_behavior(self):
dataset = MyIterableDataset(0, 100)
@@ -101,7 +101,7 @@ class TestBatchSamplerBehavior(unittest.TestCase):
samples2.append(batch)
if i == 4 - 1:
break
torch.testing.assert_allclose(torch.cat(samples), torch.cat(samples2))
self.assertEqual(torch.cat(samples), torch.cat(samples2))
def test_split_batch(self):
@@ -127,11 +127,6 @@ class TestBatchSamplerBehavior(unittest.TestCase):
global_batch_size = 16
loader = DataLoader(dataset, batch_sampler=MegatronPretrainingRandomSampler(100, 0, global_batch_size, 0, 1), num_workers=2)
batch = next(iter(loader))
# samples = None
# for i, batch in enumerate(loader):
# # samples = batch
# if i == 0:
# break
for _micro_batch_size in (1, 2, 4, 8):
microbatches = list(split_batch_into_microbatch(
@@ -139,11 +134,9 @@ class TestBatchSamplerBehavior(unittest.TestCase):
_micro_batch_size=_micro_batch_size,
_global_batch_size=global_batch_size,
))
# print(batch)
# print(microbatches)
self.assertEqual(len(microbatches), global_batch_size // _micro_batch_size)
self.assertEqual(len(microbatches[0][0]), _micro_batch_size)
if __name__ == "__main__":
unittest.main()
common_utils.run_tests()
import logging
from typing import Tuple
import torch
import torch.nn.functional as F
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.tensor_parallel import cross_entropy
from apex.transformer.testing.commons import set_random_seed, IdentityLayer
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
def torch_cross_entropy(
batch_size: int, seq_length: int, vocab_size: int, logits_scale: float, seed: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
set_random_seed(seed)
identity = IdentityLayer(
(batch_size, seq_length, vocab_size), scale=logits_scale
).cuda()
logits = identity()
target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size)
loss = (
F.cross_entropy(
logits.view(-1, logits.size()[-1]), target.view(-1), reduction="none"
)
.view_as(target)
.mean()
)
loss.backward()
return loss, identity.weight.grad
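# Note: `IdentityLayer` (apex.transformer.testing.commons) is assumed here to be a
# small module whose forward simply returns its learnable, randomly initialized
# weight tensor, so the "logits" are trainable and `identity.weight.grad` holds
# d(loss)/d(logits) for the comparison below.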
def tensor_sharded_cross_entropy(
batch_size, seq_length, vocab_size, logits_scale, seed
):
set_random_seed(seed)
identity = IdentityLayer(
(batch_size, seq_length, vocab_size), scale=logits_scale
).cuda()
logits = identity()
logits_parallel = tensor_parallel.scatter_to_tensor_model_parallel_region(logits)
target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size)
logits_parallel_ = logits_parallel.clone().detach()
loss = cross_entropy.vocab_parallel_cross_entropy(logits_parallel, target).mean()
loss.backward()
# check for mutation
assert torch.equal(logits_parallel_, logits_parallel)
return loss, identity.weight.grad
class VocabParallelCrossEntropyTestBase:
def test_cross_entropy(self):
batch_size, sequence_length, vocab_size_per_partition = 13, 17, 11
logits_scale = 1000.0
seed = 1234
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
vocab_size = vocab_size_per_partition * tensor_model_parallel_world_size
loss_torch, grad_torch = torch_cross_entropy(
batch_size, sequence_length, vocab_size, logits_scale, seed
)
(
loss_tensor_parallel,
grad_tensor_parallel,
) = tensor_sharded_cross_entropy(
batch_size, sequence_length, vocab_size, logits_scale, seed
)
self.assertEqual(loss_torch, loss_tensor_parallel)
self.assertEqual(grad_torch, grad_tensor_parallel)
parallel_state.destroy_model_parallel()
class NcclVocabParallelCrossEntropyTest(VocabParallelCrossEntropyTestBase, NcclDistributedTestBase): pass
class UccVocabParallelCrossEntropyTest(VocabParallelCrossEntropyTestBase, UccDistributedTestBase): pass
if __name__ == "__main__":
common_utils.run_tests()
import logging
import torch.testing
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import data as data_utils
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("torch").setLevel(logging.WARNING)
class BroadcastDataTestBase:
def test_broadcast_data(self):
# Use half of the world size for the tensor model parallel group when more than one rank is available.
tensor_model_parallel_world_size: int = self.world_size // (
1 + int(self.world_size > 1)
)
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size
)
target_key_size = {
"key1": [7, 11],
"key2": [8, 2, 1],
"key3": [13],
"key4": [5, 1, 2],
"key5": [5, 12],
}
keys = [k for k in target_key_size]
data = {}
data_t = {}
with torch.no_grad():
for key in target_key_size:
data[key] = torch.randint(0, 1000, size=target_key_size[key])
data_t[key] = data[key].clone()
# "key_x" is supposed to be ignored.
data["key_x"] = torch.rand(5)
data_t["key_x"] = data["key_x"].clone()
if parallel_state.get_tensor_model_parallel_rank() != 0:
data = None
data_utils._check_data_types(keys, data_t, torch.int64)
key_size, _, _ = data_utils._build_key_size_numel_dictionaries(keys, data)
for key in keys:
self.assertEqual(target_key_size[key], key_size[key])
broadcasted_data = data_utils.broadcast_data(keys, data, torch.int64)
for key in keys:
self.assertEqual(broadcasted_data[key], data_t[key].cuda())
parallel_state.destroy_model_parallel()
class NcclBroadcastDataTest(BroadcastDataTestBase, NcclDistributedTestBase): pass
class UccBroadcastDataTest(BroadcastDataTestBase, UccDistributedTestBase): pass
if __name__ == "__main__":
common_utils.run_tests()
@@ -3,9 +3,9 @@
Ref: https://github.com/NVIDIA/Megatron-LM/blob/40becfc96c4144985458ac0e0fae45dbb111fbd2/megatron/fused_kernels/tests/test_fused_kernels.py
""" # NOQA
import itertools
import unittest
import torch
from torch.testing._internal import common_utils
from apex.transformer import AttnMaskType
from apex.transformer.functional import FusedScaleMaskSoftmax
@@ -15,12 +15,20 @@ def attention_mask_func(attention_scores, attention_mask):
return attention_scores.masked_fill(attention_mask, -10000.0)
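# Example of the masking above (shapes are illustrative only):
#     scores = torch.zeros(1, 1, 1, 2)
#     mask = torch.tensor([[[[False, True]]]])
#     attention_mask_func(scores, mask)  # -> tensor([[[[0., -10000.]]]])
# Masked positions are pushed to -10000.0 so they receive ~zero softmax weight.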
autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
autocast_dtypes = (
(torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
)
class TestFusedScaleMaskSoftmax(unittest.TestCase):
def _setup_fused_softmax(self, input_in_fp16, input_in_bf16, scale=None, softmax_in_fp32=False, attn_mask_type=AttnMaskType.padding):
class TestFusedScaleMaskSoftmax(common_utils.TestCase):
def _setup_fused_softmax(
self,
input_in_fp16,
input_in_bf16,
scale=None,
softmax_in_fp32=False,
attn_mask_type=AttnMaskType.padding,
):
fused_fn = FusedScaleMaskSoftmax(
input_in_fp16=input_in_fp16,
input_in_bf16=input_in_bf16,
@@ -46,27 +54,42 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
attention_scores.shape = [4, 12, 24, 24]
mask.shape = [4, 1, 24, 24]
"""
for (dtype, scale, softmax_in_fp32) in itertools.product(
(torch.half, torch.bfloat16),
(None, 2.0),
(False, True),
for (dtype, scale, softmax_in_fp32, shape) in itertools.product(
(torch.half, torch.bfloat16), (None, 2.0), (False, True), ((4, 12, 24, 24), (32, 12, 4, 214))
):
with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"):
input_in_fp16 = dtype == torch.half
input_in_bf16 = dtype == torch.bfloat16
if not (scale is None or softmax_in_fp32):
with self.assertRaises(RuntimeError):
self._setup_fused_softmax(input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.padding)
self._setup_fused_softmax(
input_in_fp16,
input_in_bf16,
scale,
softmax_in_fp32,
AttnMaskType.padding,
)
return
fused_fn, torch_fn = self._setup_fused_softmax(input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.padding)
attention_scores_0 = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype).requires_grad_(True)
fused_fn, torch_fn = self._setup_fused_softmax(
input_in_fp16,
input_in_bf16,
scale,
softmax_in_fp32,
AttnMaskType.padding,
)
attention_scores_0 = (
torch.randn(shape)
.to(device="cuda", dtype=dtype)
.requires_grad_(True)
)
with torch.no_grad():
attention_scores_1 = attention_scores_0.clone().requires_grad_(True)
mask = torch.randint(0, 2, (4, 1, 24, 24), device="cuda").bool()
mask_shape = (shape[0],) + (1,) + shape[2:]
mask = torch.randint(0, 2, mask_shape, device="cuda").bool()
expected = fused_fn(attention_scores_0, mask)
actual = torch_fn(attention_scores_1, mask)
torch.testing.assert_allclose(actual, expected)
self.assertEqual(actual, expected)
g0 = torch.rand_like(actual)
with torch.no_grad():
@@ -80,18 +103,23 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
input_in_fp16 = dtype == torch.half
input_in_bf16 = dtype == torch.bfloat16
fused_fn, torch_fn = self._setup_fused_softmax(
input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding)
input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding
)
attention_scores_0 = torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
attention_scores_0 = (
torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
)
with torch.no_grad():
attention_scores_1 = attention_scores_0.clone().to(dtype).requires_grad_(True)
attention_scores_1 = (
attention_scores_0.clone().to(dtype).requires_grad_(True)
)
mask = torch.randint(0, 2, (4, 1, 24, 24)).bool().cuda()
expected = torch_fn(attention_scores_1, mask)
with torch.cuda.amp.autocast(dtype=dtype):
actual = fused_fn(attention_scores_0, mask)
self.assertEqual(actual.dtype, dtype)
torch.testing.assert_allclose(actual, expected)
self.assertEqual(actual, expected)
g0 = torch.rand_like(actual)
with torch.no_grad():
@@ -108,9 +136,7 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
upper elements are True and lower elements and diagonal are False.
"""
for (dtype, scale, softmax_in_fp32) in itertools.product(
(torch.half, torch.bfloat16),
(None, 2.0),
(False, True),
(torch.half, torch.bfloat16), (None, 2.0), (False, True),
):
with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"):
input_in_fp16 = dtype == torch.half
@@ -118,21 +144,37 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
if not (scale is None or softmax_in_fp32):
with self.assertRaises(RuntimeError):
self._setup_fused_softmax(
input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.causal)
input_in_fp16,
input_in_bf16,
scale,
softmax_in_fp32,
AttnMaskType.causal,
)
return
fused_fn, torch_fn = self._setup_fused_softmax(
input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.causal)
attn_weights_0 = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype).requires_grad_(True)
input_in_fp16,
input_in_bf16,
scale,
softmax_in_fp32,
AttnMaskType.causal,
)
attn_weights_0 = (
torch.randn((4, 12, 24, 24))
.to(device="cuda", dtype=dtype)
.requires_grad_(True)
)
with torch.no_grad():
attn_weights_1 = attn_weights_0.clone().requires_grad_(True)
total_mask = (~(
torch.tril(torch.randn((24, 24), device="cuda")).bool()
).unsqueeze(0).unsqueeze(0))
total_mask = (
~(torch.tril(torch.randn((24, 24), device="cuda")).bool())
.unsqueeze(0)
.unsqueeze(0)
)
total_mask = total_mask.repeat((4, 1, 1, 1))
expected = fused_fn(attn_weights_0, total_mask)
actual = torch_fn(attn_weights_1, total_mask)
torch.testing.assert_allclose(actual, expected)
self.assertEqual(actual, expected)
g0 = torch.randn_like(actual)
with torch.no_grad():
@@ -146,23 +188,33 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
input_in_fp16 = dtype == torch.half
input_in_bf16 = dtype == torch.bfloat16
fused_fn, torch_fn = self._setup_fused_softmax(
input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.causal)
input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.causal
)
attn_weights_0 = torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
attn_weights_0 = (
torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
)
with torch.no_grad():
attn_weights_1 = attn_weights_0.clone().to(dtype).requires_grad_(True)
total_mask = (~(
torch.tril(torch.randn((24, 24), device="cuda")).bool()
).unsqueeze(0).unsqueeze(0))
attn_weights_1 = (
attn_weights_0.clone().to(dtype).requires_grad_(True)
)
total_mask = (
~(torch.tril(torch.randn((24, 24), device="cuda")).bool())
.unsqueeze(0)
.unsqueeze(0)
)
with torch.cuda.amp.autocast(dtype=dtype):
actual = fused_fn(attn_weights_0, total_mask)
self.assertEqual(actual.dtype, dtype)
expected = torch_fn(attn_weights_1, total_mask)
torch.testing.assert_allclose(actual, expected)
self.assertEqual(actual, expected)
g0 = torch.randn_like(actual)
with torch.no_grad():
g1 = g0.clone()
actual.backward(g0)
expected.backward(g1)
if __name__ == "__main__":
common_utils.run_tests()
import logging
import unittest
import typing
import torch
import torch.nn as nn
from torch.testing._internal import common_utils
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import layers
from apex.transformer.testing.commons import set_random_seed
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)
# N.B.(mkozuki): Disable TF32 matrix multiply.
# Matrices used in this test are so small that TF32 matmul
# can be imprecise enough that `self.assertEqual` fails.
torch.backends.cuda.matmul.allow_tf32 = False
class TensorParallelLayerTestBase:
BATCH_SIZE: int = 8
SEQUENCE_LENGTH: int = 128
VOCAB_SIZE: int = 1024
HIDDEN_SIZE: int = 256
INPUT_SIZE_COEFF: int = 256
OUTPUT_SIZE_COEFF: int = 256
SEED: int = 123456
@property
def tensor_shape(self) -> typing.Sequence[int]:
return [self.SEQUENCE_LENGTH, self.BATCH_SIZE, self.HIDDEN_SIZE]
@torch.no_grad()
@unittest.skipIf(torch.cuda.device_count() < 2, "Requires >=2 GPUs")
def test_all_gather_parity(self) -> None:
if self.DISTRIBUTED_BACKEND == "ucc":
self.skipTest("torch_ucc does NOT support `torch.distributed._all_gather_base` as of 2022/06/15")
from torch.distributed.distributed_c10d import all_gather, _all_gather_base # NOQA
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
with self.subTest(tensor_model_parallel_world_size=tensor_model_parallel_world_size):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
cur_tensor_model_device = torch.device(f"cuda:{tensor_model_parallel_rank}")
with torch.no_grad():
tensor = tensor_model_parallel_rank * torch.ones(
self.tensor_shape, dtype=torch.float32, device=cur_tensor_model_device)
numel = tensor.numel()
numel_gathered = tensor_model_parallel_world_size * numel
gathered = torch.empty(
torch.Size((numel_gathered,)),
device=cur_tensor_model_device,
dtype=torch.float32,
requires_grad=False,
)
chunks = [
gathered[i * numel : (i + 1) * numel]
for i in range(tensor_model_parallel_world_size)
]
all_gather(chunks, tensor, group=parallel_state.get_tensor_model_parallel_group())
gathered_for_base = torch.empty(
torch.Size((numel_gathered,)),
device=cur_tensor_model_device,
dtype=torch.float32,
requires_grad=False,
)
_all_gather_base(
gathered_for_base,
tensor,
group=parallel_state.get_tensor_model_parallel_group(),
)
self.assertEqual(gathered, gathered_for_base)
parallel_state.destroy_model_parallel()
@torch.no_grad()
@unittest.skipIf(torch.cuda.device_count() < 2, "Requires >=2 GPUs")
def test_reduce_scatter_parity(self) -> None:
if self.DISTRIBUTED_BACKEND == "ucc":
self.skipTest("torch_ucc does NOT support `torch.distributed._reduce_scatter_base` as of 2022/06/15")
from torch.distributed.distributed_c10d import reduce_scatter, _reduce_scatter_base # NOQA
for tensor_model_parallel_world_size in range(2, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
with self.subTest(tensor_model_parallel_world_size=tensor_model_parallel_world_size):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
cur_tensor_model_device = torch.device(f"cuda:{tensor_model_parallel_rank}")
with torch.no_grad():
input = torch.cat([
i * torch.ones(self.tensor_shape, dtype=torch.float32, device=cur_tensor_model_device)
for i in range(tensor_model_parallel_world_size)
])
input_list = [t.clone() for t in input.chunk(tensor_model_parallel_world_size)]
output = torch.empty(
self.tensor_shape,
device=cur_tensor_model_device,
dtype=torch.float32,
requires_grad=False,
)
reduce_scatter(
output, input_list,
group=parallel_state.get_tensor_model_parallel_group(),
)
output_for_base = torch.empty(
self.tensor_shape,
device=cur_tensor_model_device,
dtype=torch.float32,
requires_grad=False,
)
_reduce_scatter_base(
output_for_base,
input,
group=parallel_state.get_tensor_model_parallel_group(),
)
self.assertEqual(output, output_for_base)
self.assertEqual(input, torch.cat(input_list))
parallel_state.destroy_model_parallel()
def test_parallel_embedding(self) -> None:
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
set_random_seed(self.SEED + 1)
input_tensor = torch.randint(
0,
self.VOCAB_SIZE,
(
self.BATCH_SIZE,
self.SEQUENCE_LENGTH,
),
device="cuda",
)
loss_weight = torch.randn(
(
self.BATCH_SIZE,
self.SEQUENCE_LENGTH,
self.HIDDEN_SIZE,
),
device="cuda",
)
set_random_seed(self.SEED)
embedding_torch = nn.Embedding(
self.VOCAB_SIZE,
self.HIDDEN_SIZE,
).cuda()
output_torch = embedding_torch(input_tensor)
loss_torch = torch.mul(output_torch, loss_weight).sum()
loss_torch.backward()
# N.B.(mkozuki): With affine weight initialization on GPU,
# it's super difficult to keep the consistency with nn.Embedding.
# Thus, turning on `use_cpu_initialization`.
set_random_seed(self.SEED)
embedding_vocab_parallel = layers.VocabParallelEmbedding(
self.VOCAB_SIZE,
self.HIDDEN_SIZE,
init_method=nn.init.normal_,
use_cpu_initialization=True,
).cuda()
output_vocab_parallel = embedding_vocab_parallel(input_tensor)
loss_vocab_parallel = torch.mul(
output_vocab_parallel, loss_weight
).sum()
loss_vocab_parallel.backward()
self.assertEqual(output_torch, output_vocab_parallel)
self.assertEqual(loss_torch, loss_vocab_parallel)
splitted_weight_torch = torch.split(
embedding_torch.weight.grad,
self.VOCAB_SIZE
// tensor_model_parallel_world_size,
0,
)[parallel_state.get_tensor_model_parallel_rank()]
self.assertEqual(
splitted_weight_torch, embedding_vocab_parallel.weight.grad
)
parallel_state.destroy_model_parallel()
def _affine_weight_init_test_impl(
self, init_device: str, is_column_parallel: bool
) -> None:
dim = int(not is_column_parallel)
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size
)
input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size
weight_shape = (
(self.OUTPUT_SIZE_COEFF, input_size)
if is_column_parallel
else (output_size, self.INPUT_SIZE_COEFF)
)
weight = torch.empty(weight_shape)
set_random_seed(self.SEED)
sharding_dim_size = (
self.OUTPUT_SIZE_COEFF
if is_column_parallel
else self.INPUT_SIZE_COEFF
)
if init_device == "cpu":
layers._initialize_affine_weight_cpu(
weight,
output_size,
input_size,
sharding_dim_size,
dim,
nn.init.normal_,
params_dtype=torch.float32,
)
else:
layers._initialize_affine_weight_gpu(
weight, torch.nn.init.normal_, dim
)
# Target
set_random_seed(self.SEED)
if init_device == "cpu":
main_weight = torch.empty(output_size, input_size)
nn.init.normal_(main_weight)
curr_weight = torch.split(main_weight, sharding_dim_size, dim=dim)[
parallel_state.get_tensor_model_parallel_rank()
]
else:
curr_weight = torch.empty(*weight_shape)
nn.init.normal_(curr_weight)
self.assertEqual(curr_weight, weight)
parallel_state.destroy_model_parallel()
def test_affine_weight_init_column_parallel_cpu(self) -> None:
self._affine_weight_init_test_impl(init_device="cpu", is_column_parallel=True)
def test_affine_weight_init_column_parallel_gpu(self) -> None:
self._affine_weight_init_test_impl(init_device="gpu", is_column_parallel=True)
def test_affine_weight_init_row_parallel_cpu(self) -> None:
self._affine_weight_init_test_impl(init_device="cpu", is_column_parallel=False)
def test_affine_weight_init_row_parallel_gpu(self) -> None:
self._affine_weight_init_test_impl(init_device="gpu", is_column_parallel=False)
def test_row_parallel_linear(self) -> None:
self._row_parallel_linear_test_impl(False, False, False)
def test_row_parallel_linear_gradient_accumulation_fusion(self) -> None:
self._row_parallel_linear_test_impl(True, False, False)
def test_row_parallel_linear_gradient_accumulation_fusion_in_fp16(self) -> None:
self._row_parallel_linear_test_impl(True, True, False)
@unittest.skipIf(torch.cuda.device_count() < 2, "Sequence Parallel requires >=2 GPUs")
def test_row_parallel_linear_sequence_parallel(self) -> None:
self._row_parallel_linear_test_impl(False, False, True)
# TODO(mkozuki): Merge this with `_column_parallel_linear_test_impl`
# Note that `input_is_parallel` is unique to `RowParallelLinear` which could make the merge complicated.
def _row_parallel_linear_test_impl(
self,
gradient_accumulation_fusion: bool,
accumulation_in_fp16: bool,
sequence_parallel_enabled: bool,
) -> None:
tensor_shape = (
self.SEQUENCE_LENGTH,
self.BATCH_SIZE,
self.HIDDEN_SIZE,
)
for tensor_model_parallel_world_size in range(
1 + int(sequence_parallel_enabled), self.world_size + 1
):
if self.world_size % tensor_model_parallel_world_size:
continue
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size,
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
set_random_seed(self.SEED)
linear = layers.RowParallelLinear(
self.HIDDEN_SIZE,
self.HIDDEN_SIZE,
keep_master_weight_for_test=True,
params_dtype=torch.float32,
use_cpu_initialization=True,
gradient_accumulation_fusion=gradient_accumulation_fusion,
accumulation_in_fp16=accumulation_in_fp16,
sequence_parallel_enabled=sequence_parallel_enabled,
# n.b.(mkozuki): RowParallelLinear is constructed with `input_is_parallel=True`
# by default, e.g. https://github.com/NVIDIA/NeMo/blob/782b4e1652aaa43c8be390d9\
# db0dc89544afa080/nemo/collections/nlp/modules/common/megatron/transformer.py#L204
input_is_parallel=True,
).cuda()
if accumulation_in_fp16:
linear = linear.half()
# Simulate the situation where fusion of weight grad calculation and gradient accumulation is enabled.
if gradient_accumulation_fusion:
with torch.no_grad():
linear.weight.main_grad = torch.zeros_like(linear.weight)
with torch.no_grad():
orig_input_tensor = torch.randn(tensor_shape, requires_grad=True, device="cuda")
orig_loss_weight = torch.randn(tensor_shape, device="cuda")
input_tensor = orig_input_tensor.chunk(
chunks=tensor_model_parallel_world_size,
dim=2,
)[parallel_state.get_tensor_model_parallel_rank()].contiguous()
if sequence_parallel_enabled:
loss_weight = orig_loss_weight.chunk(
chunks=tensor_model_parallel_world_size,
dim=0,
)[parallel_state.get_tensor_model_parallel_rank()]
else:
loss_weight = orig_loss_weight
if accumulation_in_fp16:
orig_input_tensor = orig_input_tensor.half()
input_tensor = input_tensor.half()
loss_weight = loss_weight.half()
input_tensor.requires_grad_()
output, _ = linear(input_tensor)
loss = torch.mul(output, loss_weight).sum()
loss.backward()
self.assertIsNotNone(input_tensor.grad)
ref_linear = nn.Linear(
in_features=self.HIDDEN_SIZE,
out_features=self.HIDDEN_SIZE,
bias=False,
device="cuda",
)
with torch.no_grad():
dldy = orig_loss_weight.clone()
x = orig_input_tensor.clone()
ref_linear.weight.copy_(linear.master_weight)
if accumulation_in_fp16:
ref_linear = ref_linear.half()
x.requires_grad_()
expected_output = ref_linear(x)
expected_loss = torch.mul(expected_output, dldy).sum()
expected_loss.backward()
if not accumulation_in_fp16:
if sequence_parallel_enabled:
self.assertEqual(
x=output,
y=expected_output.chunk(
chunks=tensor_model_parallel_world_size,
dim=0,
)[parallel_state.get_tensor_model_parallel_rank()],
)
else:
self.assertEqual(
x=output,
y=expected_output,
)
grad_attr_name = "main_grad" if gradient_accumulation_fusion else "grad"
# NOTE(mkozuki): Numerical errors seem to be amplified by tensor model parallelism.
if tensor_model_parallel_world_size == 1:
self.assertEqual(
x=getattr(linear.weight, grad_attr_name),
y=ref_linear.weight.grad.chunk(
chunks=tensor_model_parallel_world_size,
dim=0,
)[parallel_state.get_tensor_model_parallel_rank()],
)
parallel_state.destroy_model_parallel()
def test_column_parallel_linear(self):
self._column_parallel_linear_test_impl(False, False, False, False)
def test_column_parallel_linear_async(self):
self._column_parallel_linear_test_impl(True, False, False, False)
def test_column_parallel_linear_gradient_accumulation_fusion(self):
self._column_parallel_linear_test_impl(False, True, False, False)
def test_column_parallel_linear_gradient_accumulation_fusion_in_fp16(self):
self._column_parallel_linear_test_impl(False, True, True, False)
def test_column_parallel_linear_sequence_parallel(self):
if self.DISTRIBUTED_BACKEND == "ucc":
self.skipTest("Backward's reduce_scatter fails. as of 2022/06/15")
self._column_parallel_linear_test_impl(False, False, False, True)
@unittest.skipIf(torch.cuda.device_count() < 2, "Sequence Parallel requires >= 2 GPUs")
def test_column_parallel_linear_exception(self):
with self.assertRaisesRegex(
RuntimeError,
"`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` cannot be enabled at the same time.",
):
self._column_parallel_linear_test_impl(True, False, False, True)
def _column_parallel_linear_test_impl(
self,
async_tensor_model_parallel_allreduce: bool,
gradient_accumulation_fusion: bool,
accumulation_in_fp16: bool,
sequence_parallel_enabled: bool,
):
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if async_tensor_model_parallel_allreduce and sequence_parallel_enabled:
if tensor_model_parallel_world_size == 1:
continue
with self.subTest(tensor_model_parallel_world_size=tensor_model_parallel_world_size):
if self.world_size % tensor_model_parallel_world_size:
continue
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
input_tensor_shape = self.tensor_shape
expected_output_shape = self.tensor_shape
# When sequence parallelism is enabled, `gather_output` is disabled, i.e.,
# the matmul output is not gathered along the feature/hidden (last) dimension.
if sequence_parallel_enabled:
expected_output_shape[-1] //= tensor_model_parallel_world_size
# tensor's shape is [sequence length, batch size, hidden size]
set_random_seed(self.SEED)
linear = layers.ColumnParallelLinear(
self.HIDDEN_SIZE,
self.HIDDEN_SIZE,
bias=False,
keep_master_weight_for_test=True,
params_dtype=torch.float32,
use_cpu_initialization=True,
gather_output=not sequence_parallel_enabled,
no_async_tensor_model_parallel_allreduce=not async_tensor_model_parallel_allreduce,
gradient_accumulation_fusion=gradient_accumulation_fusion,
accumulation_in_fp16=accumulation_in_fp16,
sequence_parallel_enabled=sequence_parallel_enabled,
).cuda()
if accumulation_in_fp16:
linear = linear.half()
# Simulate the situation where fusion of weight grad calculation and gradient accumulation happens.
if gradient_accumulation_fusion:
with torch.no_grad():
linear.weight.main_grad = torch.zeros_like(linear.weight)
orig_input_tensor = torch.randn(input_tensor_shape, device="cuda", requires_grad=True)
if accumulation_in_fp16:
orig_input_tensor = orig_input_tensor.half()
if sequence_parallel_enabled:
input_tensor = list(
orig_input_tensor.chunk(tensor_model_parallel_world_size, dim=0)
)[parallel_state.get_tensor_model_parallel_rank()]
else:
input_tensor = orig_input_tensor
output, _ = linear(input_tensor)
# The order of dimension is expected to be (sequence, batch, hidden)
self.assertEqual(output.shape, expected_output_shape)
orig_loss_weight = torch.randn(input_tensor_shape, device="cuda")
if accumulation_in_fp16:
orig_loss_weight = orig_loss_weight.half()
if sequence_parallel_enabled:
loss_weight = orig_loss_weight.chunk(
tensor_model_parallel_world_size, dim=2,
)[parallel_state.get_tensor_model_parallel_rank()]
else:
loss_weight = orig_loss_weight
loss = torch.mul(output, loss_weight).sum()
loss.backward()
with torch.no_grad():
dldy = orig_loss_weight.clone()
x = orig_input_tensor.clone()
ref_linear = nn.Linear(
in_features=self.HIDDEN_SIZE,
out_features=self.HIDDEN_SIZE,
bias=False,
device="cuda",
)
if accumulation_in_fp16:
ref_linear = ref_linear.half()
# NOTE(mkozuki): `master_weight` is available because `keep_master_weight_for_test` is set.
ref_linear.weight.copy_(linear.master_weight)
x.requires_grad_()
expected_output = ref_linear(x)
if sequence_parallel_enabled:
chunk = expected_output.chunk(
tensor_model_parallel_world_size,
dim=2,
)[parallel_state.get_tensor_model_parallel_rank()]
self.assertEqual(
x=output,
y=chunk,
)
else:
self.assertEqual(
x=output,
y=expected_output,
)
expected_loss = torch.mul(expected_output, dldy).sum()
expected_loss.backward()
grad_attr_name = "main_grad" if gradient_accumulation_fusion else "grad"
# NOTE(mkozuki): Numerical errors seem to be amplified by tensor model parallelism.
if tensor_model_parallel_world_size == 1:
self.assertEqual(
x=getattr(linear.weight, grad_attr_name),
y=ref_linear.weight.grad.chunk(
chunks=tensor_model_parallel_world_size,
dim=0,
)[parallel_state.get_tensor_model_parallel_rank()],
)
parallel_state.destroy_model_parallel()
class NcclTensorParallelLayerTest(TensorParallelLayerTestBase, NcclDistributedTestBase):
pass
class UccTensorParallelLayerTest(TensorParallelLayerTestBase, UccDistributedTestBase):
pass
if __name__ == "__main__":
common_utils.run_tests()
import logging
import torch
from torch.testing._internal import common_utils
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import mappings
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)
class MappingTestBase:
def test_reduce(self):
for tensor_model_paralell_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_paralell_world_size > 0:
continue
with self.subTest(
tensor_model_paralell_world_size=tensor_model_paralell_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_paralell_world_size
)
t = torch.full((10, 10, 10, 10), 50, device=f"cuda:{self.rank}")
expected = torch.full(
(10, 10, 10, 10),
50 * tensor_model_paralell_world_size,
device=f"cuda:{self.rank}",
)
self.assertTrue(torch.equal(mappings._reduce(t), expected))
parallel_state.destroy_model_parallel()
def test_split(self):
for tensor_model_paralell_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_paralell_world_size > 0:
continue
with self.subTest(
tensor_model_paralell_world_size=tensor_model_paralell_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_paralell_world_size
)
tensors = [
torch.randn(10, 1)
for rank in range(tensor_model_paralell_world_size)
]
x = torch.cat(tensors, 1)
out = mappings._split_along_last_dim(x)
self.assertTrue(
torch.equal(
out, tensors[parallel_state.get_tensor_model_parallel_rank()]
)
)
parallel_state.destroy_model_parallel()
def test_gather(self):
for tensor_model_paralell_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_paralell_world_size > 0:
continue
with self.subTest(
tensor_model_paralell_world_size=tensor_model_paralell_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_paralell_world_size
)
device = f"cuda:{self.rank}"
gathered = mappings._gather_along_last_dim(
torch.tensor(
[parallel_state.get_tensor_model_parallel_rank()], device=device
)
)
expected = torch.tensor(
[rank for rank in range(tensor_model_paralell_world_size)],
device=device,
)
self.assertTrue(torch.equal(gathered, expected))
parallel_state.destroy_model_parallel()
class NcclMappingTest(MappingTestBase, NcclDistributedTestBase): pass
class UccMappingTest(MappingTestBase, UccDistributedTestBase): pass
if __name__ == "__main__":
common_utils.run_tests()
import logging
from typing import List, Optional
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import (
_reconfigure_microbatch_calculator,
get_micro_batch_size,
get_num_microbatches,
get_current_global_batch_size,
update_num_microbatches,
)
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
class MicrobatchCalculatorTestBase:
GLOBAL_BATCH_SIZE: int = 1024
MICRO_BATCH_SIZE: int = 1
def _test(self, rampup_batch_size: Optional[List[int]]) -> None:
for data_parallel_size in range(1, self.world_size + 1):
expected_global_batch_size = self.GLOBAL_BATCH_SIZE
expected_micro_batch_size = self.MICRO_BATCH_SIZE
if rampup_batch_size:
expected_global_batch_size = rampup_batch_size[0]
num_consumed_samples = 0
step_of_global_batch_size = rampup_batch_size[1]
threshold = rampup_batch_size[2]
if data_parallel_size > 1 and data_parallel_size % 2 != 0:
continue
if self.world_size % data_parallel_size != 0:
continue
with self.subTest(data_parallel_size=data_parallel_size):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=self.world_size // data_parallel_size,
pipeline_model_parallel_size_=1,
)
self.assertEqual(data_parallel_size, parallel_state.get_data_parallel_world_size())
_reconfigure_microbatch_calculator(
self.rank,
rampup_batch_size,
self.GLOBAL_BATCH_SIZE,
self.MICRO_BATCH_SIZE,
data_parallel_size,
)
self.assertEqual(get_micro_batch_size(), expected_micro_batch_size)
self.assertEqual(get_num_microbatches(), expected_global_batch_size / expected_micro_batch_size / data_parallel_size)
current_global_batch_size = get_current_global_batch_size()
self.assertEqual(current_global_batch_size, expected_global_batch_size)
# Make sure `global_batch_size` equals the final global batch size after a
# certain number of updates.
if rampup_batch_size:
update_num_microbatches(current_global_batch_size)
for i in range(100):
current_global_batch_size = get_current_global_batch_size()
update_num_microbatches(current_global_batch_size)
current_global_batch_size = get_current_global_batch_size()
self.assertEqual(get_current_global_batch_size(), self.GLOBAL_BATCH_SIZE)
parallel_state.destroy_model_parallel()
def test_constant_microbatch_calculator(self):
self._test(rampup_batch_size=None)
def test_dynamic_microbatch_calculator(self):
self._test(rampup_batch_size=[256, 128, 500])
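# The ramp-up triple above is read as
# [initial global batch size, batch size increment, ramp-up sample threshold]
# (see how `_test` unpacks it): the effective global batch size starts at 256 and
# is expected to grow in steps of 128 until it reaches GLOBAL_BATCH_SIZE (1024),
# which the update loop in `_test` verifies.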
class NcclMicrobatchCalculatorTest(MicrobatchCalculatorTestBase, NcclDistributedTestBase): pass
class UccMicrobatchCalculatorTest(MicrobatchCalculatorTestBase, UccDistributedTestBase): pass
if __name__ == "__main__":
common_utils.run_tests()
import logging
import unittest
import torch
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import p2p_communication
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("apex").setLevel(logging.DEBUG)
# [P2P Ops Involved in Pipeline Model Parallel forward/backward]
# **forward_backward_pipelining_without_interleaving**
# - send_forward / recv_forward
# - send_backward / recv_backward
# - send_forward_recv_backward
# - send_backward_recv_forward
# **forward_backward_pipelining_with_interleaving**
# - send_backward_recv_backward
# - recv_backward
# - recv_forward
# - send_forward_backward_recv_forward_backward
# - send_forward_recv_forward
class P2PCommTestBase:
numel = 4
shape = (2, 2)
dtype = torch.float32
@property
def world_size(self):
return min(2, torch.cuda.device_count())
def _init_model_parallel(self):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=1,
pipeline_model_parallel_size_=self.world_size,
virtual_pipeline_model_parallel_size_=None,
)
def create_tensor(self, value: int = None):
return torch.tensor(
[value] * self.numel).view(self.shape).to(device="cuda", dtype=self.dtype)
# Brief: Simulate warm-up.
# Brief: test `recv_forward` & `send_forward`.
def test_no_interleaving_warmup(self):
self.assertEqual(self.world_size, 2)
self._init_model_parallel()
input_tensor = None
if parallel_state.is_pipeline_first_stage():
tensor = self.create_tensor(self.rank)
print(tensor)
p2p_communication.send_forward(output_tensor=tensor, tensor_shape=self.shape, dtype=self.dtype)
else:
input_tensor = p2p_communication.recv_forward(tensor_shape=self.shape, dtype=self.dtype)
if parallel_state.is_pipeline_first_stage():
self.assertIsNone(input_tensor)
else:
expected_input_tensor = self.create_tensor(self.rank - 1)
self.assertEqual(input_tensor, expected_input_tensor)
# Brief: test `send_forward`, `send_forward_recv_forward`, and `recv_forward`.
def test_send_forward_recv_forward(self):
self._init_model_parallel()
prev_tensor = None
tensor = self.create_tensor(self.rank)
if parallel_state.is_pipeline_first_stage():
p2p_communication.send_forward(output_tensor=tensor, tensor_shape=self.shape, dtype=self.dtype)
elif parallel_state.is_pipeline_last_stage():
prev_tensor = p2p_communication.recv_forward(tensor_shape=self.shape, dtype=self.dtype)
else:
prev_tensor = p2p_communication.send_forward_recv_forward(
output_tensor=tensor,
recv_prev=True,
tensor_shape=self.shape,
dtype=self.dtype,
)
if parallel_state.is_pipeline_first_stage():
self.assertIsNone(prev_tensor)
else:
expected_prev_tensor = self.create_tensor(self.rank - 1)
self.assertEqual(prev_tensor, expected_prev_tensor)
# Brief: test `send_backward`, `send_backward_recv_backward`, and `recv_backward`.
def test_send_backward_recv_backward(self):
self._init_model_parallel()
tensor = self.create_tensor(self.rank)
next_tensor = None
if parallel_state.is_pipeline_first_stage():
next_tensor = p2p_communication.recv_backward(tensor_shape=self.shape, dtype=self.dtype)
elif parallel_state.is_pipeline_last_stage():
p2p_communication.send_backward(input_tensor_grad=tensor, tensor_shape=self.shape, dtype=self.dtype)
else:
next_tensor = p2p_communication.send_backward_recv_backward(
input_tensor_grad=tensor,
recv_next=True,
tensor_shape=self.shape,
dtype=self.dtype,
)
if parallel_state.is_pipeline_last_stage():
self.assertIsNone(next_tensor)
else:
expected_next_tensor = self.create_tensor(self.rank + 1)
self.assertEqual(next_tensor, expected_next_tensor)
# n.b.(mkozuki): Intentionally skip NCCL backend tests as I trust pytorch/pytorch repo.
class UccP2PCommTest(P2PCommTestBase, UccDistributedTestBase): pass
if __name__ == "__main__":
common_utils.run_tests()
import logging
import os
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
os.environ["BACKEND"] = "NCCL"
DATA_PARALLEL_WORLD_SIZE: int = 1
def calc_expected_tensor_model_paralell_rank(
rank: int, tensor_model_parallel_world_size: int,
) -> int:
return rank % tensor_model_parallel_world_size
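# Worked example: with tensor_model_parallel_world_size = 2, ranks 0..3 map to
# tensor model parallel ranks 0, 1, 0, 1, and the corresponding source ranks
# checked below are (rank // 2) * 2 = 0, 0, 2, 2.
assert [calc_expected_tensor_model_paralell_rank(r, 2) for r in range(4)] == [0, 1, 0, 1]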
class ParallelStateTestBase:
def test_initialize_model_parallel(self) -> None:
self.assertFalse(parallel_state.model_parallel_is_initialized())
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
if self.world_size % tensor_model_parallel_world_size:
continue
pipeline_model_parallel_world_size = (
self.world_size // tensor_model_parallel_world_size
)
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
)
self.assertEqual(
tensor_model_parallel_world_size,
parallel_state.get_tensor_model_parallel_world_size(),
)
expected_tensor_model_parallel_rank = calc_expected_tensor_model_paralell_rank(
self.rank, tensor_model_parallel_world_size
)
self.assertEqual(
expected_tensor_model_parallel_rank,
parallel_state.get_tensor_model_parallel_rank(),
)
expected_tensor_model_parallel_src_rank = (
self.rank // tensor_model_parallel_world_size
) * tensor_model_parallel_world_size
self.assertEqual(
expected_tensor_model_parallel_src_rank,
parallel_state.get_tensor_model_parallel_src_rank(),
)
parallel_state.destroy_model_parallel()
self.assertFalse(parallel_state.model_parallel_is_initialized())
def test_initialize_model_parallel_with_virtual_and_split(self) -> None:
if self.world_size < 4:
self.skipTest("requires >= 4 GPUs")
self.assertFalse(parallel_state.model_parallel_is_initialized())
tensor_model_parallel_world_size = 1 + int(self.world_size > 4)
pipeline_model_parallel_world_size = (
self.world_size // tensor_model_parallel_world_size
)
virtual_pipeline_model_parallel_world_size = 2
pipeline_model_parallel_split_rank = pipeline_model_parallel_world_size // 2
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_world_size,
pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
)
self.assertEqual(
calc_expected_tensor_model_paralell_rank(
self.rank, tensor_model_parallel_world_size
),
parallel_state.get_tensor_model_parallel_rank(),
)
self.assertEqual(
pipeline_model_parallel_world_size,
parallel_state.get_pipeline_model_parallel_world_size(),
)
self.assertEqual(
virtual_pipeline_model_parallel_world_size,
parallel_state.get_virtual_pipeline_model_parallel_world_size(),
)
expected_pipeline_rank = (
self.rank - (self.rank % tensor_model_parallel_world_size)
) % pipeline_model_parallel_world_size
self.assertEqual(
expected_pipeline_rank, parallel_state.get_pipeline_model_parallel_rank(),
)
# virtual pipeline model parallel rank is lazily set, i.e., right after the call of
# `initialize_model_parallel`, it's set to 0.
self.assertEqual(
0, parallel_state.get_virtual_pipeline_model_parallel_rank(),
)
self.assertEqual(
pipeline_model_parallel_split_rank,
parallel_state.get_pipeline_model_parallel_split_rank(),
)
fake_split_rank = 77
parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank)
self.assertEqual(
fake_split_rank, parallel_state.get_pipeline_model_parallel_split_rank()
)
# relative position embedding groups check
self.assertEqual(
expected_pipeline_rank < pipeline_model_parallel_split_rank,
parallel_state.is_rank_in_encoder_relative_position_embedding_group(),
)
self.assertEqual(
expected_pipeline_rank >= pipeline_model_parallel_split_rank,
parallel_state.is_rank_in_decoder_relative_position_embedding_group(),
)
parallel_state.destroy_model_parallel()
def test_initialize_model_parallel_decoder_only(self) -> None:
"""Initialize model parallelism for decoder-only Transformers like GPT-3"""
self.assertFalse(parallel_state.model_parallel_is_initialized())
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
if self.world_size % tensor_model_parallel_world_size:
continue
pipeline_model_parallel_world_size = (
self.world_size // tensor_model_parallel_world_size
)
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
pipeline_model_parallel_split_rank_=0,
)
self.assertEqual(
tensor_model_parallel_world_size,
parallel_state.get_tensor_model_parallel_world_size(),
)
expected_tensor_model_parallel_rank = calc_expected_tensor_model_paralell_rank(
self.rank, tensor_model_parallel_world_size
)
self.assertEqual(
expected_tensor_model_parallel_rank,
parallel_state.get_tensor_model_parallel_rank(),
)
expected_tensor_model_parallel_src_rank = (
self.rank // tensor_model_parallel_world_size
) * tensor_model_parallel_world_size
self.assertEqual(
expected_tensor_model_parallel_src_rank,
parallel_state.get_tensor_model_parallel_src_rank(),
)
parallel_state.destroy_model_parallel()
self.assertFalse(parallel_state.model_parallel_is_initialized())
class NcclParallelStateTest(ParallelStateTestBase, NcclDistributedTestBase): pass
class UccParallelStateTest(ParallelStateTestBase, UccDistributedTestBase): pass
if __name__ == "__main__":
common_utils.run_tests()
import logging
import itertools
import re
from typing import Optional, Tuple, List
import unittest
import torch
from torch.testing._internal import common_utils
from torch.testing._internal import common_cuda
from apex._autocast_utils import _get_autocast_dtypes
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.pipeline_parallel import utils as pp_utils
from apex.transformer.pipeline_parallel.schedules.common import (
FwdStepFunc,
build_model,
_get_params_for_weight_decay_optimization,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import (
forward_backward_no_pipelining,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import (
_forward_backward_pipelining_with_interleaving,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
forward_backward_pipelining_without_interleaving,
)
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
from apex.transformer.testing.distributed_test_base import HAS_TORCH_UCC
from apex.transformer.testing.distributed_test_base import HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER
from apex.transformer.testing import commons as testing_utils
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)
weight_coeff = 1024
def get_init_weights_func(offset: int = 0):
@torch.no_grad()
def init_weights(m):
rank = parallel_state.get_pipeline_model_parallel_rank()
if isinstance(m, torch.nn.Linear):
m.weight.fill_((rank + offset + 1.0) / weight_coeff)
m.bias.fill_(1.0)
return init_weights
def get_dtype_for_comparison():
if(torch.cuda.get_device_capability() >= (8, 0)):
return torch.float64
return torch.float32
def get_target_loss_and_model(global_batch_shape: tuple, hidden_size: int, total_layers: int) -> Tuple[torch.Tensor, List[torch.Tensor]]:
model = []
dtype = get_dtype_for_comparison()
data = torch.ones(global_batch_shape, dtype=dtype)
for i in range(total_layers):
w = torch.ones((hidden_size, hidden_size), dtype=dtype) * (i + 1.0) / weight_coeff
b = torch.ones(hidden_size, dtype=dtype)
w.requires_grad_()
b.requires_grad_()
# don't need to care about transpose semantics as all values are the same
data = torch.matmul(w, data) + b
model.append([w, b])
loss = data.sum() / global_batch_shape[0]
loss.backward()
return loss, model
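# Arithmetic behind the analytical reference above: with an all-constant input of
# value c, layer i (weight filled with w_i = (i + 1) / weight_coeff, bias 1) maps
# every entry to hidden_size * w_i * c + 1, so the target loss can be computed
# deterministically and compared against the pipelined runs below.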
def _get_default_world_sizes_model_parallel_world_size(pipeline_model_parallel_world_size: Optional[int] = None
) -> Tuple[int, int, int]:
# TODO: revisit if we can fold this into the class for skip logic / avoid duplication
# of world size computation
world_size = torch.cuda.device_count()
tensor_model_parallel_world_size = 1
data_parallel_size = 1 + (world_size >= 8 and world_size % 2 == 0)
if pipeline_model_parallel_world_size is None:
pipeline_model_parallel_world_size = world_size // (tensor_model_parallel_world_size * data_parallel_size)
else:
data_parallel_size = world_size // (tensor_model_parallel_world_size * pipeline_model_parallel_world_size)
return tensor_model_parallel_world_size, data_parallel_size, pipeline_model_parallel_world_size
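# Worked example of the default decomposition above (when
# `pipeline_model_parallel_world_size` is None): 8 GPUs -> (tensor=1, data=2,
# pipeline=8 // (1 * 2) = 4); 4 GPUs -> (1, 1, 4); 2 GPUs -> (1, 1, 2).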
class PipelineParallelForwardBackwardTestBase:
GLOBAL_BATCH_SIZE = 16
MICRO_BATCH_SIZE = 2
HIDDEN_SIZE = 32
deallocate_options = (True, False)
# If :obj:`None`, (torch.float32, torch.float16, torch.bfloat16) are dtype options on Ampere.
# You can limit the options by overriding the following `dtypes`.
dtypes = None
def _forward_backward_test_impl(
self,
forward_only: bool,
fwd_bwd_func: FwdStepFunc,
pipeline_model_parallel_world_size: Optional[int],
virtual_pipeline_model_parallel_size: Optional[int],
async_comm: bool = False,
*,
default_backend: Optional[str] = None,
p2p_backend: Optional[str] = None,
) -> None:
if fwd_bwd_func == _forward_backward_pipelining_with_interleaving:
self.assertIsNotNone(virtual_pipeline_model_parallel_size)
self.assertGreater(virtual_pipeline_model_parallel_size, 1)
dtype_options = self.dtypes or [torch.float32, torch.double] + _get_autocast_dtypes()
for dtype, deallocate_pipeline_outputs in itertools.product(
dtype_options, self.deallocate_options,
):
grad_scaler = (
torch.cuda.amp.GradScaler(init_scale=4.0)
if dtype == torch.half
else None
)
(tensor_model_parallel_world_size,
data_parallel_size,
pipeline_model_parallel_world_size) = _get_default_world_sizes_model_parallel_world_size(pipeline_model_parallel_world_size)
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
default_backend=default_backend,
p2p_backend=p2p_backend,
)
pp_utils._reconfigure_microbatch_calculator(
rank=parallel_state.get_tensor_model_parallel_rank(),
rampup_batch_size=None,
global_batch_size=self.GLOBAL_BATCH_SIZE,
micro_batch_size=self.MICRO_BATCH_SIZE,
data_parallel_size=parallel_state.get_data_parallel_world_size(),
)
global_batch_shape = (
self.GLOBAL_BATCH_SIZE
// parallel_state.get_data_parallel_world_size(),
self.HIDDEN_SIZE,
self.HIDDEN_SIZE,
)
batch = None
if parallel_state.is_pipeline_first_stage():
batch = (torch.ones(global_batch_shape, dtype=dtype).cuda(), )
model = build_model(
testing_utils.model_provider_func,
# Wrap with DDP only when data parallelism is actually in use (data_parallel_size > 1).
wrap_with_ddp=data_parallel_size > 1,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
hidden_size=self.HIDDEN_SIZE,
)
offset = pipeline_model_parallel_world_size if virtual_pipeline_model_parallel_size is not None else 0
for idx, model_module in enumerate(model):
model_module = model_module.to(dtype)
model_module.apply(get_init_weights_func(idx*offset))
_param_groups = _get_params_for_weight_decay_optimization(model)
optimizer = torch.optim.Adam(_param_groups, lr=1e-3)
pp_utils.update_num_microbatches(0)
loss = fwd_bwd_func(
testing_utils.fwd_step_func,
batch,
model,
forward_only=forward_only,
# `tensor_shape` is the shape of micro batch.
tensor_shape=(
self.MICRO_BATCH_SIZE,
self.HIDDEN_SIZE,
self.HIDDEN_SIZE,
),
dtype=dtype,
async_comm=async_comm,
grad_scaler=grad_scaler,
deallocate_pipeline_output=deallocate_pipeline_outputs,
)
if dtype == get_dtype_for_comparison():
torch.cuda.synchronize()
hidden_size = self.HIDDEN_SIZE
microbatch_size = self.MICRO_BATCH_SIZE
total_layers = pipeline_model_parallel_world_size
if virtual_pipeline_model_parallel_size is not None:
total_layers *= virtual_pipeline_model_parallel_size
target_loss, target_model = get_target_loss_and_model(global_batch_shape, hidden_size, total_layers)
for loss_item in loss:
x = loss_item['avg']
self.assertEqual(x.item() / microbatch_size, target_loss.item())
if not forward_only:
for vm_id, model_module in enumerate(model):
params = list(model_module.parameters())
rank = params[0].get_device()
offset = pipeline_model_parallel_world_size
param_id = rank // data_parallel_size + vm_id * offset
target_params = target_model[param_id]
self.assertEqual(params[0].cpu(), target_params[0])
self.assertEqual(params[1].cpu(), target_params[1])
self.assertEqual(params[0].grad.cpu() / microbatch_size, target_params[0].grad)
self.assertEqual(params[1].grad.cpu() / microbatch_size, target_params[1].grad)
if not forward_only:
for m in model:
for p in m.parameters():
self.assertIsNotNone(p.grad)
optimizer.step()
optimizer.zero_grad(set_to_none=True)
parallel_state.destroy_model_parallel()
def test_learning_no_pipelining(self):
self._forward_backward_test_impl(False, forward_backward_no_pipelining, 1, None)
def test_inference_no_pipelining(self):
self._forward_backward_test_impl(True, forward_backward_no_pipelining, 1, None)
def test_learning_pipelining_without_interleaving(self):
self._forward_backward_test_impl(
False, forward_backward_pipelining_without_interleaving, None, None
)
def test_inference_pipelining_without_interleaving(self):
self._forward_backward_test_impl(
True, forward_backward_pipelining_without_interleaving, None, None
)
def test_learning_async_pipelining_without_interleaving(self):
self._forward_backward_test_impl(
False, forward_backward_pipelining_without_interleaving, None, None, async_comm=True
)
def test_inference_async_pipelining_without_interleaving(self):
self._forward_backward_test_impl(
True, forward_backward_pipelining_without_interleaving, None, None, async_comm=True
)
@unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
def test_learning_pipelining_with_interleaving(self):
self._forward_backward_test_impl(
False, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2
)
@unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
def test_inference_pipelining_with_interleaving(self):
self._forward_backward_test_impl(
True, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2
)
@unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
def test_learning_async_pipelining_with_interleaving(self):
self._forward_backward_test_impl(
False, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2, async_comm=True
)
@unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
def test_inference_async_pipelining_with_interleaving(self):
self._forward_backward_test_impl(
True, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2, async_comm=True
)
class NcclPipelineParallelForwardBackwardTest(NcclDistributedTestBase, PipelineParallelForwardBackwardTestBase):
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 8)
def _run_hybrid_distributed_backend(self, forward_only: bool) -> None:
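        # NCCL backs the default process groups; UCC is used only for the point-to-point
        # (pipeline) communication.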
self._forward_backward_test_impl(
forward_only, forward_backward_pipelining_without_interleaving, None, None,
default_backend="nccl", p2p_backend="ucc",
)
@unittest.skipUnless(HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER, "Needs driver >= 470.42.01")
def _test_hybrid_backends(self, forward_only: bool) -> None:
if HAS_TORCH_UCC:
self._run_hybrid_distributed_backend(forward_only)
else:
with self.assertRaisesRegex(
ImportError,
re.escape("UCC backend requires [torch_ucc](https://github.com/facebookresearch/torch_ucc) but not found"),
):
self._run_hybrid_distributed_backend(forward_only)
def test_learning_pipelining_without_interleaving_ucc_for_p2p(self):
self._test_hybrid_backends(False)
def test_inference_pipelining_without_interleaving_ucc_for_p2p(self):
self._test_hybrid_backends(True)
# n.b.(mkozuki): pipeline parallel w/o interleaving with UCX_TLS=tcp,sm fails.
class UccPipelineParallelForwardBackwardTest(UccDistributedTestBase, PipelineParallelForwardBackwardTestBase):
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 8)
deallocate_options = (False,)
dtypes = (torch.float32,)
# Sanity checking the functionality of `forward_backward_pipelining_without_interleaving` with
# `model_type=ModelType.encoder_and_decoder` which is used for pipeline training of transformer
# models such as T5.
@unittest.skipIf(torch.cuda.device_count() < 4, "Requires >= 4 GPUs")
class NcclPipelineParallelWithToyParallelMLP(NcclDistributedTestBase):
GLOBAL_BATCH_SIZE = 16
MICRO_BATCH_SIZE = 2
HIDDEN_SIZE = 64
    # TODO(mkozuki): Change `DECODER_SEQUENCE_LENGTH` to a value different from `ENCODER_SEQUENCE_LENGTH`.
    # To test forward_backward_pipelining_without_interleaving with `model_type=ModelType.encoder_and_decoder`,
    # `decoder_seq_length` is required and ideally should differ from `encoder_sequence_length`,
    # but the same value is used here for simplicity.
    # Note that supporting a different `DECODER_SEQUENCE_LENGTH` may require updating the
    # `MyModel` definition or defining another `MyModel`.
ENCODER_SEQUENCE_LENGTH = 32
DECODER_SEQUENCE_LENGTH = 32
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 8)
# TODO(mkozuki): Add cases of async_comm=True
# TODO(mkozuki): Add loss check.
# TODO(mkozuki): Call `build_model` with `model_type`.
# TODO(mkozuki): Set `tensor_model_parallel>1` for encoder_and_decoder as well if there's enough GPUs
# in order to let `sequence_parallel_enabled` have an effect on tensor shape logic.
def _forward_backward_test_impl(
self,
*,
forward_only: bool,
sequence_parallel_enabled: bool,
model_type: ModelType,
dtype: torch.dtype = torch.float32,
) -> None:
# N.B.(mkozuki): It might be better to set `tensor_model_parallel_size` to >1
# if `self.world_size > 5`. Otherwise, `pipeline_model_parallel_split_rank`
        # can be 1, which is too far from a real use case.
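        # Use tensor model parallel size 2 when at least 4 GPUs are available, otherwise 1.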
tensor_model_parallel_size = 1 + int(self.world_size >= 4)
pipeline_model_parallel_world_size = self.world_size // tensor_model_parallel_size
if model_type == ModelType.encoder_and_decoder:
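            # The split rank marks where the pipeline switches from encoder stages to decoder stages.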
pipeline_model_parallel_split_rank = pipeline_model_parallel_world_size // 2
else:
pipeline_model_parallel_split_rank = None
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_size,
pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
virtual_pipeline_model_parallel_size_=None,
pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
)
testing_utils.set_random_seed(567)
pp_utils._reconfigure_microbatch_calculator(
rank=parallel_state.get_tensor_model_parallel_rank(),
rampup_batch_size=None,
global_batch_size=self.GLOBAL_BATCH_SIZE,
micro_batch_size=self.MICRO_BATCH_SIZE,
data_parallel_size=parallel_state.get_data_parallel_world_size(),
)
model = build_model(
testing_utils.mlp_provider_func,
wrap_with_ddp=False,
virtual_pipeline_model_parallel_size=None,
hidden_size=self.HIDDEN_SIZE,
sequence_parallel_enabled=sequence_parallel_enabled,
)
model = [m.to(dtype=dtype) for m in model]
if parallel_state.is_pipeline_first_stage():
batch: Tuple[torch.Tensor] = (
torch.ones(
(self.GLOBAL_BATCH_SIZE, self.ENCODER_SEQUENCE_LENGTH, self.HIDDEN_SIZE),
dtype=dtype,
device="cuda",
),
)
else:
batch = None
forward_backward_pipelining_without_interleaving(
forward_step_func=testing_utils.ToyParallelMLPFwdBwdStepFunc(
sequence_parallel_enabled=sequence_parallel_enabled,
),
batch=batch,
model=model,
forward_only=forward_only,
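            # `tensor_shape` follows the sequence-first layout: (seq_length, micro_batch_size, hidden_size).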
tensor_shape=(
self.ENCODER_SEQUENCE_LENGTH,
self.MICRO_BATCH_SIZE,
self.HIDDEN_SIZE,
),
model_type=model_type,
decoder_sequence_length=self.DECODER_SEQUENCE_LENGTH,
async_comm=False,
grad_scaler=None,
deallocate_pipeline_outputs=False,
dtype=dtype,
sequence_parallel_enabled=sequence_parallel_enabled,
)
def test_pipelining_without_interleaving_encoder_and_decoder(self) -> None:
self._forward_backward_test_impl(forward_only=False, sequence_parallel_enabled=False, model_type=ModelType.encoder_and_decoder)
    def test_pipelining_without_interleaving_inference_encoder_and_decoder(self) -> None:
        self._forward_backward_test_impl(forward_only=True, sequence_parallel_enabled=False, model_type=ModelType.encoder_and_decoder)
    def test_pipelining_without_interleaving_sequence_parallel_encoder_and_decoder(self) -> None:
        self._forward_backward_test_impl(forward_only=False, sequence_parallel_enabled=True, model_type=ModelType.encoder_and_decoder)
    def test_pipelining_without_interleaving_inference_sequence_parallel_encoder_and_decoder(self) -> None:
        self._forward_backward_test_impl(forward_only=True, sequence_parallel_enabled=True, model_type=ModelType.encoder_and_decoder)
def test_pipelining_without_interleaving_encoder_or_decoder(self) -> None:
self._forward_backward_test_impl(forward_only=False, sequence_parallel_enabled=False, model_type=ModelType.encoder_or_decoder)
def test_pipelining_without_interleaving_sequence_parallel_encoder_or_decoder(self) -> None:
self._forward_backward_test_impl(forward_only=False, sequence_parallel_enabled=True, model_type=ModelType.encoder_or_decoder)
def test_pipelining_without_interleaving_sequence_parallel_encoder_or_decoder_half(self) -> None:
self._forward_backward_test_impl(forward_only=False, sequence_parallel_enabled=True, model_type=ModelType.encoder_or_decoder, dtype=torch.half)
if __name__ == "__main__":
common_utils.run_tests()
import logging
import torch
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
class TransformerRandomTestBase:
def test_set_cuda_rng_state(self):
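        # Saving the CUDA RNG state and restoring it later via `_set_cuda_rng_state`
        # must reproduce exactly the same random sequence.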
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size
)
size, seed = 123, 1234
torch.cuda.manual_seed(seed)
tensor = torch.cuda.FloatTensor(size)
rng_state = torch.cuda.get_rng_state()
rng_state_clone = rng_state.clone()
for _ in range(5):
torch.randn(size, out=tensor)
result_1 = tensor.clone()
self.assertEqual(rng_state.sub(rng_state_clone).max(), 0)
self.assertGreater(
torch.cuda.get_rng_state().sub(rng_state_clone).max(), 0
)
new_rng_state = torch.cuda.get_rng_state()
self.assertGreater(new_rng_state.sub(rng_state).max(), 0)
tensor_parallel.random._set_cuda_rng_state(rng_state)
for _ in range(5):
torch.randn(size, out=tensor)
tensor_parallel.random._set_cuda_rng_state(rng_state)
for _ in range(5):
torch.randn(size, out=tensor)
result_2 = tensor.clone()
self.assertEqual(result_2, result_1)
self.assertEqual(rng_state.sub(rng_state_clone).max(), 0)
parallel_state.destroy_model_parallel()
def test_cuda_rng_tracker(self):
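        # Draws interleaved between the default generator and a tracker fork seeded with
        # `seed_2` must match the sequences obtained from each seed separately.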
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size
)
seed_1, seed_2, size = 1234, 4321, [12, 21]
tensor = torch.cuda.FloatTensor(size)
torch.cuda.manual_seed(seed_1)
torch.randn(size, out=tensor)
target_11 = tensor.clone()
torch.randn(size, out=tensor)
target_12 = tensor.clone()
torch.cuda.manual_seed(seed_2)
torch.randn(size, out=tensor)
                target_21 = tensor.clone()
torch.randn(size, out=tensor)
target_22 = tensor.clone()
torch.cuda.manual_seed(seed_1)
tensor_parallel.random.get_cuda_rng_tracker().add("test", seed_2)
torch.randn(size, out=tensor)
result_11 = tensor.clone()
with tensor_parallel.random.get_cuda_rng_tracker().fork("test"):
torch.randn(size, out=tensor)
result_21 = tensor.clone()
torch.randn(size, out=tensor)
result_12 = tensor.clone()
with tensor_parallel.random.get_cuda_rng_tracker().fork("test"):
torch.randn(size, out=tensor)
result_22 = tensor.clone()
self.assertEqual(target_11, result_11)
self.assertEqual(target_12, result_12)
                self.assertEqual(target_21, result_21)
self.assertEqual(target_22, result_22)
self.assertNotEqual(result_11, result_21)
self.assertNotEqual(result_21, result_22)
tensor_parallel.random.get_cuda_rng_tracker().reset()
parallel_state.destroy_model_parallel()
class NcclTransformerRandomTest(TransformerRandomTestBase, NcclDistributedTestBase): pass
class UccTransformerRandomTest(TransformerRandomTestBase, UccDistributedTestBase): pass
if __name__ == "__main__":
common_utils.run_tests()
......@@ -5,31 +5,26 @@ import sys
import unittest
DENY_TEST = [
"megatron_gpt_pipeline",
]
MULTIGPU_TEST = [
"pipeline_parallel_test",
"dynamic_batchsize_test",
]
SEVERALGPU_TEST = [
"bert_minimal_test",
"gpt_minimal_test",
"dynamic_batchsize_test",
]
def get_multigpu_launch_option(min_gpu):
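    # Skip when fewer than `min_gpu` devices are visible; otherwise launch the test with
    # torch.distributed.run across all available GPUs.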
should_skip = False
import torch
num_devices = torch.cuda.device_count()
if num_devices < min_gpu:
should_skip = True
distributed_run_options = f"-m torch.distributed.run --nproc_per_node={num_devices}"
return should_skip, distributed_run_options
def get_launch_option(test_filename) -> Tuple[bool, str]:
should_skip = False
for multigpu_test in MULTIGPU_TEST:
if multigpu_test in test_filename:
return get_multigpu_launch_option(2)
for severalgpu_test in SEVERALGPU_TEST:
if severalgpu_test in test_filename:
return get_multigpu_launch_option(3)
......@@ -38,11 +33,10 @@ def get_launch_option(test_filename) -> Tuple[bool, str]:
def run_transformer_tests():
python_executable_path = sys.executable
# repository_root = os.path.join(os.path.dirname(__file__), "../../../")
# directory = os.path.abspath(os.path.join(repository_root, "tests/mpu"))
directory = os.path.dirname(__file__)
files = [
os.path.join(directory, f) for f in os.listdir(directory)
os.path.join(directory, f)
for f in os.listdir(directory)
if f.startswith("run_") and os.path.isfile(os.path.join(directory, f))
]
print("#######################################################")
......@@ -52,36 +46,45 @@ def run_transformer_tests():
errors = []
for i, test_file in enumerate(files, 1):
is_denied = False
for deny_file in DENY_TEST:
if deny_file in test_file:
is_denied = True
if is_denied:
print(f"### {i} / {len(files)}: {test_file} skipped")
continue
should_skip, launch_option = get_launch_option(test_file)
if should_skip:
print(f"### {i} / {len(files)}: {test_file} skipped. Requires multiple GPUs.")
print(
f"### {i} / {len(files)}: {test_file} skipped. Requires multiple GPUs."
)
continue
test_run_cmd = (
f"{python_executable_path} {launch_option} {test_file} "
"--micro-batch-size 4 --num-layers 16 --hidden-size 768 --num-attention-heads 8 --max-position-embeddings "
"512 --seq-length 512 --global-batch-size 256"
"--micro-batch-size 2 --num-layers 16 --hidden-size 256 --num-attention-heads 8 --max-position-embeddings "
"512 --seq-length 512 --global-batch-size 128"
)
if 'bert' in test_file:
if "bert" in test_file or "gpt" in test_file:
import torch
num_devices = torch.cuda.device_count()
test_run_cmd += f" --pipeline-model-parallel-size {num_devices}"
if "bert" in test_file:
# "bert" uses the interleaving.
tensor_model_parallel_size = 2 if num_devices % 2 == 0 and num_devices > 4 else 1
if "gpt" in test_file:
# "gpt" uses the non-interleaving.
tensor_model_parallel_size = 2 if num_devices % 2 == 0 and num_devices >= 4 else 1
pipeline_model_parallel_size = num_devices // tensor_model_parallel_size
test_run_cmd += f" --pipeline-model-parallel-size {pipeline_model_parallel_size} --tensor-model-parallel-size {tensor_model_parallel_size}"
if "bert" in test_file:
test_run_cmd += f" --bert-no-binary-head"
else:
test_run_cmd += f" --use-cpu-initialization"
print(f"### {i} / {len(files)}: cmd: {test_run_cmd}")
try:
output = subprocess.check_output(
test_run_cmd, shell=True
).decode(sys.stdout.encoding).strip()
output = (
subprocess.check_output(test_run_cmd, shell=True)
.decode(sys.stdout.encoding)
.strip()
)
except Exception as e:
errors.append((test_file, str(e)))
else:
if '>> passed the test :-)' not in output:
if ">> passed the test :-)" not in output:
errors.append((test_file, output))
else:
if not errors:
......@@ -96,10 +99,9 @@ def run_transformer_tests():
class TestTransformer(unittest.TestCase):
def test_transformer(self):
run_transformer_tests()
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
import logging
import torch
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import utils
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
class TransformerUtilsTest(NcclDistributedTestBase):
def test_split_tensor_along_last_dim(self):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            with self.subTest(
                tensor_model_parallel_world_size=tensor_model_parallel_world_size
            ):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size
)
device = "cpu"
input_tensor = torch.randn((100, 100, 100), device=device)
splits = utils.split_tensor_along_last_dim(input_tensor, 10)
last_dim_shapes = torch.tensor(
[int(split.size()[-1]) for split in splits]
)
self.assertTrue(torch.equal(last_dim_shapes, torch.full((10,), 10),))
parallel_state.destroy_model_parallel()
if __name__ == "__main__":
common_utils.run_tests()
import os
import logging
import itertools
from typing import Optional, Tuple, List
import unittest
import torch
from torch.testing._internal import common_utils
from torch.testing._internal import common_cuda
from torch.testing._internal import common_distributed
from apex._autocast_utils import _get_autocast_dtypes
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import utils as pp_utils
from apex.transformer.pipeline_parallel.schedules.common import (
FwdStepFunc,
build_model,
_get_params_for_weight_decay_optimization,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import (
forward_backward_no_pipelining,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import (
_forward_backward_pipelining_with_interleaving,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
forward_backward_pipelining_without_interleaving,
)
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
from apex.transformer.testing import commons as testing_utils
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)
def _get_default_world_sizes_model_parallel_world_size(
    pipeline_model_parallel_world_size: Optional[int] = None,
) -> Tuple[int, int, int]:
# TODO: revisit if we can fold this into the class for skip logic / avoid duplication
# of world size computation
world_size = torch.cuda.device_count()
tensor_model_parallel_world_size = 1
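    # Use data parallel size 2 only for an even GPU count of at least 8; otherwise 1.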
data_parallel_size = 1 + (world_size >= 8 and world_size % 2 == 0)
if pipeline_model_parallel_world_size is None:
pipeline_model_parallel_world_size = world_size // (tensor_model_parallel_world_size * data_parallel_size)
else:
data_parallel_size = world_size // (tensor_model_parallel_world_size * pipeline_model_parallel_world_size)
return tensor_model_parallel_world_size, data_parallel_size, pipeline_model_parallel_world_size
class UccPipelineParallelForwardBackwardProf(UccDistributedTestBase):
    # This class exists to confirm asynchronous communication via profiling, so it is safe
    # to skip all numerical checks.
    # For unit tests with numerical checks, refer to `tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py`.
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.GLOBAL_BATCH_SIZE = 1024
self.MICRO_BATCH_SIZE = 64
self.HIDDEN_SIZE = 256
self.NUM_FWD_BWD_ITERATIONS = 4
self.deallocate_options = (False,)
self.dtypes = (torch.float32,)
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 8)
def _forward_backward_test_impl(
self,
forward_only: bool,
fwd_bwd_func: FwdStepFunc,
pipeline_model_parallel_world_size: Optional[int],
virtual_pipeline_model_parallel_size: Optional[int],
async_comm: bool = False,
*,
default_backend: Optional[str] = None,
p2p_backend: Optional[str] = None,
) -> None:
if fwd_bwd_func == _forward_backward_pipelining_with_interleaving:
self.assertIsNotNone(virtual_pipeline_model_parallel_size)
self.assertGreater(virtual_pipeline_model_parallel_size, 1)
dtype_options = self.dtypes or [torch.float32, torch.double] + _get_autocast_dtypes()
for dtype, deallocate_pipeline_outputs in itertools.product(
dtype_options, self.deallocate_options,
):
grad_scaler = (
torch.cuda.amp.GradScaler(init_scale=4.0)
if dtype == torch.half
else None
)
(tensor_model_parallel_world_size,
data_parallel_size,
pipeline_model_parallel_world_size) = _get_default_world_sizes_model_parallel_world_size(pipeline_model_parallel_world_size)
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
default_backend=default_backend,
p2p_backend=p2p_backend,
)
pp_utils._reconfigure_microbatch_calculator(
rank=parallel_state.get_tensor_model_parallel_rank(),
rampup_batch_size=None,
global_batch_size=self.GLOBAL_BATCH_SIZE,
micro_batch_size=self.MICRO_BATCH_SIZE,
data_parallel_size=parallel_state.get_data_parallel_world_size(),
)
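            # Each data-parallel replica sees GLOBAL_BATCH_SIZE // data_parallel_size samples,
            # each of shape (HIDDEN_SIZE, HIDDEN_SIZE).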
global_batch_shape = (
self.GLOBAL_BATCH_SIZE
// parallel_state.get_data_parallel_world_size(),
self.HIDDEN_SIZE,
self.HIDDEN_SIZE,
)
batch = None
if parallel_state.is_pipeline_first_stage():
batch = (torch.ones(global_batch_shape, dtype=dtype).cuda(), )
model = build_model(
testing_utils.model_provider_func,
                # Wrap with DDP only when data parallelism is actually used.
wrap_with_ddp=data_parallel_size > 1,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
hidden_size=self.HIDDEN_SIZE,
)
offset = pipeline_model_parallel_world_size if virtual_pipeline_model_parallel_size is not None else 0
for idx, model_module in enumerate(model):
model_module = model_module.to(dtype)
_param_groups = _get_params_for_weight_decay_optimization(model)
optimizer = torch.optim.Adam(_param_groups, lr=1e-3)
pp_utils.update_num_microbatches(0)
for _ in range(self.NUM_FWD_BWD_ITERATIONS):
loss = fwd_bwd_func(
testing_utils.fwd_step_func,
batch,
model,
forward_only=forward_only,
# `tensor_shape` is the shape of micro batch.
tensor_shape=(
self.MICRO_BATCH_SIZE,
self.HIDDEN_SIZE,
self.HIDDEN_SIZE,
),
dtype=dtype,
async_comm=async_comm,
grad_scaler=grad_scaler,
deallocate_pipeline_output=deallocate_pipeline_outputs,
)
parallel_state.destroy_model_parallel()
def test_learning_no_pipelining(self):
self._forward_backward_test_impl(False, forward_backward_no_pipelining, 1, None)
def test_inference_no_pipelining(self):
self._forward_backward_test_impl(True, forward_backward_no_pipelining, 1, None)
def test_learning_pipelining_without_interleaving(self):
self._forward_backward_test_impl(
False, forward_backward_pipelining_without_interleaving, None, None
)
def test_inference_pipelining_without_interleaving(self):
self._forward_backward_test_impl(
True, forward_backward_pipelining_without_interleaving, None, None
)
def test_learning_async_pipelining_without_interleaving(self):
self._forward_backward_test_impl(
False, forward_backward_pipelining_without_interleaving, None, None, async_comm=True
)
def test_inference_async_pipelining_without_interleaving(self):
self._forward_backward_test_impl(
True, forward_backward_pipelining_without_interleaving, None, None, async_comm=True
)
@unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
def test_learning_pipelining_with_interleaving(self):
self._forward_backward_test_impl(
False, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2
)
@unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
def test_inference_pipelining_with_interleaving(self):
self._forward_backward_test_impl(
True, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2
)
@unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
def test_learning_async_pipelining_with_interleaving(self):
self._forward_backward_test_impl(
False, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2, async_comm=True
)
@unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
def test_inference_async_pipelining_with_interleaving(self):
self._forward_backward_test_impl(
True, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2, async_comm=True
)
if __name__ == "__main__":
os.environ["UCC_TLS"] = "ucp,cuda"
common_distributed.TIMEOUT_DEFAULT = 500
common_utils.run_tests()