Unverified Commit 63d5dd63 authored by Masaki Kozuki's avatar Masaki Kozuki Committed by GitHub
Browse files

Pipeline Model Parallel (#1202)



* Init apex.ppu (pipeline model parallel utility)

Reference commit:

```
commit 5ab646376d67831601d5552c193241d017f1b35c (HEAD -> main, internal/main)
Merge: 14f2c684 7b293d9b
Author: Mohammad Shoeybi <mshoeybi@nvidia.com>
Date:   Wed Sep 22 22:57:54 2021 -0700

    Merge branch 'add_BOS' into 'main'

    Add Beginning of Sentence token option and adding semaphore while multi-threading to prevent crashes and hangs due to connection keep-alives

    See merge request ADLR/megatron-lm!328
```

* removing get_args and replace import - phase 1

* removing get_args and replace import - phase 2

* move ppu to apex.transformer.pipeline_parallel

* update two __init__.py

* update READMEs

* mpu -> parallel_state & tensor_parallel

* fix

* remove not pipeline files

* separate schedules.py - phase 1

* dissect schedules.py

* data_iterators -> batch

* remove optimizer from forward_backward_step funcs

* init test

* Apply 2 suggestion(s) to 2 file(s)

* fix cyclic import

* fix syntax of Callable

* fix - 1

* move directory as testing used for pp test as well

* add some functions for num microbatches calculator

* model is a list in pipeline parallel

* skip build num microbatch calculator

* fix test

* assert -> raise

* skip args printing

* specify tensor shape everywhere even if None - phase 1

* private timers

* passing tensor shape & dtype around

* update dtype handling by introducing helper func

* write helper func to reduce cyclomatic complexity

* remove duplicate

* update

* move split_tensor_into_1d_equal_chunks to avoid cyclic import

* tmp

* cosmetic

* move gather_split_1d_tensor to avoid cyclic imports

* remove debug print

* add outer loop

* early return if possible

* cosmetic

* passing around tensor shape

* refactor test

* add script to learn batch sampler behavior

* update

* minibatch splitter

* add minibatch splitter

* split minibatch into microbatches

* minor changes

* uncomment split batch for test sake

* set as attribute

* study the behavior of no pipelining

* debug 1

* reflect test util namespace change

* update readme

* cosmetic in test

* add model build helper func for interleaving shced

* adding model builder from megatron

* canbe cyclic import

* fix

* enable interleaving test, but failing even if forward only

* fix batch preparation

* add explanation

* print data parallel size

* fix typo

* Add Megatron style GPT model by Rishi
Co-authored-by: default avatarRishi Puri <riship@nvidia.com>

* update

* type hint for jit

* fix forward_backward_no_pipelining test

* pipeline forward backward seem to hang if not forward only

* fix typo

* debug

* add p2p test

* simplify

* fix

* tentative

* set both tmp and pmp to 1

* init

* fix typo

* fix

* fix path of divide

* set seed for tmp

* update upon Eddie comment

* fix typo

* adding failing data loader test

* fix

* megatron still failing

* check in

* with the nested loop of new order, interleaving seems fine

* cosmetic change

* make `forward_backward_pipelining_with_interleaving private

* warn users that interleaving sched is unstable

* move noop handler to no pipelining

* comment out rank_print

* make `build_model` more flexible

* skip megatron test tentatively

* correctly comment out rank_print

* correctly comment out rank_print

* correctly comment out rank_print

* skip appropriately

* remove wip p2p comm test

* update type hint of model_provider_func

* disable tf32 in each test script

* skip interleaving w/ backward

* rename as mpu is the old name

* remove broken case

* expose build_model func

* delete `dist.ring_exchange` func call and `use_ring_exchange` argument

* nit fixes

* check in

* remove unused file

* update the list

* update tensor shape

* remove mixed dtype case

* use torch.distributed.run

* 2020 -> 2021

* another 2020 -> 2021

* docstring & type hint

* fix teardown

* update

* change to experimental

* check if warned
Co-authored-by: default avatarRishi Puri <riship@nvidia.com>
Co-authored-by: default avatarEddie Yan <eddiey@nvidia.com>
parent 3303b3e7
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -19,10 +19,10 @@ import torch
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import data as data_utils
from apex.transformer.tensor_parallel.tests import global_vars
from apex.transformer.tensor_parallel.tests.commons import print_separator
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
......@@ -82,6 +82,8 @@ def test_broadcast_data(tensor_model_parallel_size):
if __name__ == '__main__':
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -15,10 +15,10 @@
import torch
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel.tests import global_vars
from apex.transformer.tensor_parallel.tests.commons import print_separator
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
......@@ -90,6 +90,8 @@ def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
if __name__ == '__main__':
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,11 +18,11 @@ from torch.nn.parameter import Parameter
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import layers
from apex.transformer.tensor_parallel.tests import global_vars
from apex.transformer.tensor_parallel.tests.commons import set_random_seed
from apex.transformer.tensor_parallel.tests.commons import print_separator
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import set_random_seed
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
......@@ -584,7 +584,6 @@ def test_parallel_transformer_layer(tensor_model_parallel_size):
if __name__ == '__main__':
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
......
import torch
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel import mappings
from apex.transformer.tensor_parallel.tests import global_vars
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import initialize_distributed
global_vars.set_global_variables()
......@@ -48,6 +48,8 @@ def test__gather(args, tensor_model_parallel_size):
if __name__ == "__main__":
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
......
from functools import partial
from typing import List
import torch
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization
from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import forward_backward_pipelining_without_interleaving
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group
from apex.transformer.pipeline_parallel.utils import get_ltor_masks_and_position_ids
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import update_num_microbatches
from apex.transformer.tensor_parallel import model_parallel_cuda_manual_seed
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.standalone_gpt import gpt_model_provider
from apex.transformer.utils import rank_print
global_vars.set_global_variables()
N_VOCAB = 8192
def generate_batch(batch_size, sequence_length):
size = batch_size, sequence_length + 1
int_tensor = torch.randint(low=0, high=N_VOCAB, size=size, dtype=torch.long).cuda()
return int_tensor,
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L44
def get_batch(int_tensors: List[torch.Tensor]):
data = int_tensors[0]
# Unpack.
tokens_ = data.long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
N_VOCAB, # tokenizer.eod,
False, # args.reset_position_ids,
False, # args.reset_attention_mask,
False, # args.eod_mask_loss,
)
return tokens, labels, loss_mask, attention_mask, position_ids
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L75
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L86
# TODO (mkozuki): Currently I'm seeing no attribute `word_embeddings` which looks weird.
def forward_step(batch, model):
"""Forward step."""
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(batch)
output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
return output_tensor, partial(loss_func, loss_mask)
def run_gpt(pipeline_model_parallel_size, virtual_pipeline_model_parallel_size=None, forward_only=False):
parallel_state.initialize_model_parallel(1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size)
model_parallel_cuda_manual_seed(42)
model = build_model(
gpt_model_provider, True,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size)
# rank_print("building model")
assert isinstance(model, list)
assert len(model) == (1 or virtual_pipeline_model_parallel_size)
_param_groups = _get_params_for_weight_decay_optimization(model)
torch.optim.Adam(_param_groups)
if parallel_state.is_pipeline_last_stage():
# rank_print("checking `word_embeddings` existence")
for m in model:
assert hasattr(m, "word_embeddings")
args = global_vars.get_args()
if virtual_pipeline_model_parallel_size is None:
batch = generate_batch(args.global_batch_size, args.seq_length)
else:
batch = [generate_batch(args.global_batch_size, args.seq_length) for _ in range(virtual_pipeline_model_parallel_size)]
# rank_print("preparing batch")
if virtual_pipeline_model_parallel_size is None:
fwd_bwd_func = forward_backward_pipelining_without_interleaving
else:
fwd_bwd_func = _forward_backward_pipelining_with_interleaving
# rank_print(f"selecting forward_backward func: {fwd_bwd_func}")
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
# rank_print(f"`tensor_shape`: {tensor_shape}")
fwd_bwd_func(forward_step, batch, model, forward_only=forward_only, tensor_shape=tensor_shape)
# rank_print(TEST_SUCCESS_MESSAGE)
if __name__ == "__main__":
initialize_distributed()
args = global_vars.get_args()
args.padded_vocab_size = N_VOCAB
setup_microbatch_calculator(
args.rank,
args.rampup_batch_size,
args.global_batch_size,
args.micro_batch_size,
1, # args.data_parallel_size,
)
update_num_microbatches(0, True)
print_separator("run GPT model")
try:
run_gpt(torch.distributed.get_world_size())
# TODO(mkozuki): handle exception correctly, but for now, lazily commenting out as
# this won't get kicked by CI
except Exception as e:
# rank_print(str(e))
pass
finally:
parallel_state.destroy_model_parallel()
from typing import Optional, Union, List
import torch
import torch.nn as nn
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import get_forward_backward_func
from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization
from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import forward_backward_pipelining_without_interleaving
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import update_num_microbatches
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator
from apex.transformer.utils import rank_print
global_vars.set_global_variables()
batch_size, micro_batch_size = None, None
hidden_size = 16
fwd_bwd_functions = {
"no_pipelining": forward_backward_no_pipelining,
"no_interleaving": forward_backward_pipelining_without_interleaving,
"interleaving": _forward_backward_pipelining_with_interleaving,
}
# note (mkozuki): `pre_process` and `post_process` are a placeholder until interleaving schedule test comes.
class MyLayer(nn.Module):
def __init__(self, pre_process: bool, post_process: bool):
super().__init__()
self.pre_process = pre_process
self.post_process = post_process
self.layer = nn.Linear(hidden_size, hidden_size)
def forward(self, x):
return self.layer(x)
class MyModel(nn.Module):
def __init__(self, pre_process: bool = False, post_process: bool = False) -> None:
super().__init__()
self.pre_process = pre_process
self.post_process = post_process
self.layer = MyLayer(pre_process=pre_process, post_process=post_process)
self.input_tensor = None
def set_input_tensor(self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]) -> None:
self.input_tensor = input_tensor
def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor:
if self.input_tensor is None:
return self.layer(x)
return self.layer(self.input_tensor)
def model_provider_func(pre_process, post_process) -> MyModel:
return MyModel(pre_process, post_process)
def process_batch(batch):
if isinstance(batch, list):
x = batch[0]
else:
x = batch
return x
def fwd_step_func(batch, model):
x = process_batch(batch)
y = model(x)
# note (mkozuki): I don't think this function is nice but I do think this is enough for now
# just to check the sanity of ported pipeline functions.
def loss_func(x):
loss = torch.sum(x)
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'avg': averaged_loss}
return y, loss_func
# TODO (mkozuki): Add a case with `autocast` and `GradScaler`.
# Run forward & backward for one minibatch.
def forward_backward_func_template(
name: str,
forward_backward_func,
pipeline_model_parallel_size: int,
forward_only: bool,
) -> None:
print_separator(f"name: {name}, pipeline model parallel size: {pipeline_model_parallel_size}")
virtual_pipeline_model_parallel_size = 2 if name == "interleaving" else None
if name == "no_pipelining":
# note (mkozuki): `forward_backward_no_pipelining` is **NOTE** compatible with
# pipeline_model_parallel_size>1. So use pipeline_model_parallel_size as
# tensor_model_parallel_size and set pipeline_model_parallel_size to 1.
parallel_state.initialize_model_parallel(1, 1, None)
else:
# NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is necessary to enable interleaving scheduling
# In megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and
# used ubiquitously but this test uses custom model so it's safe to abuse.
parallel_state.initialize_model_parallel(
1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size)
if virtual_pipeline_model_parallel_size is not None:
# Check the experimental warning message
get_forward_backward_func(virtual_pipeline_model_parallel_size, pipeline_model_parallel_size)
pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
model = build_model(
model_provider_func,
wrap_with_ddp=True,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
)
assert isinstance(model, list)
assert len(model) == (1 if virtual_pipeline_model_parallel_size is None else virtual_pipeline_model_parallel_size)
_param_groups = _get_params_for_weight_decay_optimization(model)
torch.optim.Adam(_param_groups)
tensor_shape = [batch_size // parallel_state.get_data_parallel_world_size(), hidden_size]
if virtual_pipeline_model_parallel_size is None:
batch = (torch.randn(tensor_shape).cuda(),)
else:
batch = [(torch.randn(tensor_shape).cuda(),) for _ in range(virtual_pipeline_model_parallel_size)]
tensor_shape[0] = micro_batch_size
update_num_microbatches(0)
forward_backward_func(
fwd_step_func, batch, model, forward_only=forward_only, tensor_shape=tensor_shape)
if not forward_only:
# rank_print("grad check")
for m in model:
for p in m.parameters():
if p.grad is None:
raise RuntimeError("grad not found")
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == "__main__":
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
n_tests = 0
failures = []
initialize_distributed()
world_size = torch.distributed.get_world_size()
args = global_vars.get_args()
batch_size = args.global_batch_size
micro_batch_size = args.micro_batch_size
setup_microbatch_calculator(
args.rank,
args.rampup_batch_size,
args.global_batch_size,
args.micro_batch_size,
1, # args.data_parallel_size,
)
for forward_only in (True, False):
for name, forward_backward_func in fwd_bwd_functions.items():
n_tests += 1
# TODO (mkozuki): Test with data parallel size > 1.
pipeline_model_parallel_size = world_size
try:
forward_backward_func_template(
name,
forward_backward_func,
pipeline_model_parallel_size,
forward_only,
)
except Exception as e:
failures.append(
f"\t# {name} failed with pipeline size: {pipeline_model_parallel_size} "
f"and forward_only: {forward_only}\n"
f"pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()}, "
f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}\n"
f"{str(e)}"
)
finally:
parallel_state.destroy_model_parallel()
else:
print_separator(f"{name} works")
print_separator("TEST RESULT")
if failures:
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print("\n".join(failures))
msg = f"{len(failures)} / {n_tests} cases failed"
raise RuntimeError(msg)
else:
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print("### PASS!")
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -16,10 +16,10 @@ import torch
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.tensor_parallel.tests import global_vars
from apex.transformer.tensor_parallel.tests.commons import print_separator
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
......@@ -188,6 +188,8 @@ def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):
if __name__ == '__main__':
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
......
......@@ -15,6 +15,8 @@ def test_split_tensor_along_last_dim():
if __name__ == "__main__":
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
test_divide()
test_split_tensor_along_last_dim()
print(">> passed the test :-)")
from itertools import product
import unittest
import torch
from torch.utils.data import Dataset
from torch.utils.data import RandomSampler
from torch.utils.data import BatchSampler
from torch.utils.data import DataLoader
from apex.transformer.pipeline_parallel.utils import _split_batch_into_microbatch as split_batch_into_microbatch
class MyIterableDataset(Dataset):
def __init__(self, start, end):
super().__init__()
assert end > start, "this example code only works with end >= start"
self.start = start
self.end = end
self.samples = list(range(self.start, self.end))
def __iter__(self):
return iter(range(self.start, self.end))
def __getitem__(self, index):
return self.samples[index]
class MegatronPretrainingRandomSampler:
def __init__(self, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.last_batch_size = \
self.total_samples % self.micro_batch_times_data_parallel_size
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
def __len__(self):
return self.total_samples
def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
# data sharding and random sampling
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
batch = []
# Last batch if not complete will be dropped.
for idx in idx_range:
batch.append(idx)
if len(batch) == self.micro_batch_size:
self.consumed_samples += self.micro_batch_times_data_parallel_size
yield batch
batch = []
# Samples 8 tensors in total.
# First sample 4 tensors twice, then sample 2 tensors fourth.
class TestBatchSamplerBehavior(unittest.TestCase):
def test_batch_sampler_behavior(self):
dataset = MyIterableDataset(0, 100)
for num_workers in (1, 2, 4):
with self.subTest(f"{num_workers}"):
torch.manual_seed(42)
loader = DataLoader(dataset, batch_sampler=MegatronPretrainingRandomSampler(100, 0, 4, 0, 1), num_workers=num_workers)
samples = []
for i, batch in enumerate(loader):
samples.append(batch)
if i == 2 - 1:
break
torch.manual_seed(42)
loader = DataLoader(dataset, batch_sampler=MegatronPretrainingRandomSampler(100, 0, 2, 0, 1), num_workers=num_workers)
samples2 = []
for i, batch in enumerate(loader):
samples2.append(batch)
if i == 4 - 1:
break
torch.testing.assert_allclose(torch.cat(samples), torch.cat(samples2))
def test_split_batch(self):
class MyIterableDataset(Dataset):
def __init__(self, start, end):
super().__init__()
assert end > start, "this example code only works with end >= start"
self.start = start
self.end = end
self.samples = list(range(self.start, self.end))
def __len__(self):
return self.end - self.start
def __iter__(self):
return iter(range(self.start, self.end))
def __getitem__(self, index):
return (torch.tensor([index, index]), torch.tensor([index // 2, index // 2]))
dataset = MyIterableDataset(0, 100)
torch.manual_seed(42)
global_batch_size = 16
loader = DataLoader(dataset, batch_sampler=MegatronPretrainingRandomSampler(100, 0, global_batch_size, 0, 1), num_workers=2)
batch = next(iter(loader))
# samples = None
# for i, batch in enumerate(loader):
# # samples = batch
# if i == 0:
# break
for _micro_batch_size in (1, 2, 4, 8):
microbatches = list(split_batch_into_microbatch(
batch,
_micro_batch_size=_micro_batch_size,
_global_batch_size=global_batch_size,
))
# print(batch)
# print(microbatches)
self.assertEqual(len(microbatches), global_batch_size // _micro_batch_size)
self.assertEqual(len(microbatches[0][0]), _micro_batch_size)
if __name__ == "__main__":
unittest.main()
from typing import Tuple
import os
import subprocess
import sys
import unittest
def run_mpu_tests():
DENY_TEST = [
"megatron_gpt_pipeline",
]
MULTIGPU_TEST = [
"pipeline_parallel_test",
]
def get_launch_option(test_filename) -> Tuple[bool, str]:
should_skip = False
for multigpu_test in MULTIGPU_TEST:
if multigpu_test in test_filename:
import torch
num_devices = torch.cuda.device_count()
if num_devices < 2:
should_skip = True
distributed_run_options = f"-m torch.distributed.run --nproc_per_node={num_devices}"
return should_skip, distributed_run_options
return should_skip, ""
def run_transformer_tests():
python_executable_path = sys.executable
# repository_root = os.path.join(os.path.dirname(__file__), "../../../")
# directory = os.path.abspath(os.path.join(repository_root, "tests/mpu"))
......@@ -19,7 +41,22 @@ def run_mpu_tests():
print("#######################################################")
errors = []
for i, test_file in enumerate(files, 1):
test_run_cmd = f"NVIDIA_TF32_OVERRIDE=0 {python_executable_path} {test_file} --micro-batch-size 2 --num-layers 1 --hidden-size 256 --num-attention-heads 8 --max-position-embeddings 32 --encoder-seq-length 32 --use-cpu-initialization" # NOQA
is_denied = False
for deny_file in DENY_TEST:
if deny_file in test_file:
is_denied = True
if is_denied:
print(f"### {i} / {len(files)}: {test_file} skipped")
continue
should_skip, launch_option = get_launch_option(test_file)
if should_skip:
print(f"### {i} / {len(files)}: {test_file} skipped. Requires multiple GPUs.")
continue
test_run_cmd = (
f"{python_executable_path} {launch_option} {test_file} "
"--micro-batch-size 2 --num-layers 1 --hidden-size 256 --num-attention-heads 8 --max-position-embeddings "
"32 --encoder-seq-length 32 --use-cpu-initialization"
)
print(f"### {i} / {len(files)}: cmd: {test_run_cmd}")
try:
output = subprocess.check_output(
......@@ -29,7 +66,7 @@ def run_mpu_tests():
errors.append((test_file, str(e)))
else:
if '>> passed the test :-)' not in output:
errors.append(test_file, output)
errors.append((test_file, output))
else:
if not errors:
print("### PASSED")
......@@ -42,10 +79,10 @@ def run_mpu_tests():
raise RuntimeError(short_msg)
class TestMPU(unittest.TestCase):
class TestTransformer(unittest.TestCase):
def test_mpu(self):
run_mpu_tests()
def test_transformer(self):
run_transformer_tests()
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment