Unverified Commit a0ed4151 authored by Masaki Kozuki, committed by GitHub

[transformer] Format & Test Refactoring (#1325)

* try PyTorch custom TestCase class

* revert

* initial working example

* update

* data utils

* fix imports

* hardcode backend to nccl

* fix signature

* fix typo

* mapping

* set device

* init

* refactor x entropy

* remove unused import & destroy model parallel

* refactor random

* fix test

* remove migrated tests

* refactor

* init

* separate affine weight init

* init model parallel

* split more

* weight init fix part 1

* use cpu init for consistency btwn native and tensor parallel

* black

* add col parallel

* use a 3D tensor of square matrix for column parallel linear

* skip the failing cases

* migrate layers test

* pipeline parallel forward/backward

* fix typo

* fix typo

* fix

* fix pipeline world size

* black

* rm `run_pipeline_parallel_test` in favor of test_pipeline_parallel_fwd_bwd.py

* stop logging

* set log level

* black

* license and format

* fix

* skip tf32 as matrices are small

* remove potentially inappropriate license

* Apply suggestions from code review

* remove `TODO` comment

* `torch.testing.assert_allclose` -> `torch.testing.assert_close`

* remove comment-outs

* remove unused import

* minor fix
parent f10b4b89
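One of the commit-message bullets above replaces `torch.testing.assert_allclose` with `torch.testing.assert_close`. A minimal sketch of what that migration looks like in isolation (the tensors below are made-up examples, not taken from this commit):

# Sketch of the assert_allclose -> assert_close migration mentioned above.
import torch

actual = torch.tensor([1.0, 2.0, 3.0])
expected = torch.tensor([1.0, 2.0, 3.0])

# Deprecated spelling removed by the commit:
# torch.testing.assert_allclose(actual, expected)

# Replacement used by the refactored tests; raises AssertionError on mismatch
# and accepts explicit tolerances as keyword arguments.
torch.testing.assert_close(actual, expected)
torch.testing.assert_close(actual, expected, rtol=1e-6, atol=1e-6)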
@@ -5,10 +5,14 @@ from apex.transformer import tensor_parallel
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import vocab_parallel_cross_entropy
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import (
    average_losses_across_data_parallel_group,
)
from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func
from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.common import (
    _get_params_for_weight_decay_optimization,
)
from apex.transformer.testing.standalone_bert import bert_model_provider
from apex.transformer.testing import global_vars
@@ -17,9 +21,12 @@ from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator

import warnings


class DebugWarning(Warning):
    pass


mode = None
MANUAL_SEED = 42
inds = None
@@ -30,62 +37,74 @@ EASY_MODE = False
EASY_MODE_SIZ = 32
ONCE = False


def download_fancy_data():
    # import requests
    # response = requests.get('https://internet.com/book.txt')
    # text = ' '.join(response.text.split())
    text = """
  An original sentence not subject to any license restrictions, copyright, or royalty payments. Nothing to see here. Commercial or non-commercial use. Research or non-research purposes. The quick brown fox jumps over the lazy dog. Lorem ipsum.
  """
    text = text * 1024
    encoded = text.encode("ascii", "replace")
    ints = [int(encoded[i]) for i in range(len(encoded))]
    return torch.tensor(ints)


# build a batch given sequence_len and batch size
def generate_fancy_data_labels(sequence_len, batch_size):
    global data_idx
    global inds
    global masks
    global MANUAL_SEED
    temps = []
    for i in range(batch_size):
        if inds is None or data_idx >= len(inds):
            # hack as use of RNG will fall out of sync due to pipelines being different
            torch.manual_seed(MANUAL_SEED)
            inds = torch.randperm(effective_length, device="cuda")
            masks = (
                torch.rand(
                    len(inds) // batch_size + 1, batch_size, sequence_len, device="cuda"
                )
                >= MASK_PROB
            ).long()
            MANUAL_SEED += 1
            print("new epoch", len(inds))
            data_idx = 0
            print("my start", inds[0:5])
            print("masks_checksum:", torch.sum(masks))
        if EASY_MODE:
            data_idx_ = data_idx % EASY_MODE_SIZ
        else:
            data_idx_ = data_idx
        offset = inds[data_idx_]  # * SEQUENCE_LEN
        data_idx += 1

        curr = fancy_data[offset : offset + sequence_len].clone().detach()
        temps.append(curr)
    temp = torch.stack(temps, dim=0).cuda()
    mask = masks[data_idx // batch_size]
    mask_not = torch.logical_not(mask).long()
    data = mask * temp + mask_not * 124
    label = temp
    if parallel_state.get_tensor_model_parallel_rank() == 0:
        data_dict = {"text": data, "label": label, "mask_not": mask_not}
    else:
        data_dict = None
    keys = ["text", "label", "mask_not"]
    dtype = torch.int64
    broadcasted_data = tensor_parallel.broadcast_data(keys, data_dict, torch.long)
    return (
        broadcasted_data["text"].long(),
        broadcasted_data["label"].long(),
        broadcasted_data["mask_not"],
    )


easy_data = None


def fwd_step_func(batch, model):
    data, label, loss_mask = batch
    y = model(data, torch.ones_like(data), lm_labels=label)
@@ -94,31 +113,38 @@ def fwd_step_func(batch, model):
        global ONCE
        output_tensor, _ = output_tensor
        lm_loss_ = output_tensor.float()
        lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
        averaged_loss = average_losses_across_data_parallel_group([lm_loss])
        if data_idx >= 1536:
            assert lm_loss < 4.8
            if not ONCE:
                print("LOSS OK")
                ONCE = True
        return lm_loss, {"avg": averaged_loss}

    return y, loss_func


def train(
    model, optim, virtual_pipeline_model_parallel_size, pipeline_model_parallel_size
):
    sequence_len = global_vars.get_args().seq_length
    micro_batch_size = global_vars.get_args().micro_batch_size
    hidden_size = global_vars.get_args().hidden_size
    forward_backward_func = get_forward_backward_func(
        virtual_pipeline_model_parallel_size, pipeline_model_parallel_size
    )
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    for _ in range(16):
        batch = generate_fancy_data_labels(sequence_len, batch_size)
        optim.zero_grad()
        forward_backward_func(
            fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape
        )
        optim.step()


if __name__ == "__main__":
    global fancy_data
    global effective_length
@@ -128,13 +154,12 @@ if __name__ == '__main__':
    effective_length = fancy_data.size(0) // global_vars.get_args().seq_length
    effective_length = fancy_data.size(0) - global_vars.get_args().seq_length
    initialize_distributed()
    world_size = torch.distributed.get_world_size()
    failure = None
    try:
        args = global_vars.get_args()
        args.padded_vocab_size = 128  # needed in standalone gpt
        batch_size = args.global_batch_size
        micro_batch_size = args.micro_batch_size
        setup_microbatch_calculator(
@@ -147,27 +172,44 @@ if __name__ == '__main__':
        virtual_pipeline_model_parallel_size = 2
        pipeline_model_parallel_size = world_size
        parallel_state.initialize_model_parallel(
            args.tensor_model_parallel_size,
            args.pipeline_model_parallel_size,
            virtual_pipeline_model_parallel_size,
        )
        pipeline_model_parallel_size = (
            parallel_state.get_pipeline_model_parallel_world_size()
        )
        tensor_parallel.random.model_parallel_cuda_manual_seed(0)
        model = build_model(
            bert_model_provider,
            wrap_with_ddp=True,
            virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
            cpu_offload=args.cpu_offload,
        )
        assert isinstance(model, list)
        assert len(model) == (
            1
            if virtual_pipeline_model_parallel_size is None
            else virtual_pipeline_model_parallel_size
        )
        _param_groups = _get_params_for_weight_decay_optimization(model)
        optim = torch.optim.Adam(_param_groups)
        print(effective_length)
        print(fancy_data.size(0))
        train(
            model,
            optim,
            virtual_pipeline_model_parallel_size,
            args.pipeline_model_parallel_size,
        )
    except Exception as e:
        failure = str(e)
    finally:
        parallel_state.destroy_model_parallel()
    if failure is not None:
        warnings.warn(
            f"Minimal BERT Pipeline Parallel Failed with: {failure}", DebugWarning
        )
        print(f"Minimal BERT Pipeline Parallel Failed with: {failure}")
    torch.distributed.barrier()
    print(TEST_SUCCESS_MESSAGE)
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import set_random_seed
from apex.transformer.testing.commons import IdentityLayer
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
def torch_cross_entropy(batch_size, seq_length, vocab_size,
logits_scale, seed):
set_random_seed(seed)
identity = IdentityLayer((batch_size, seq_length, vocab_size),
scale=logits_scale).cuda()
logits = identity()
target = torch.cuda.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size)
loss = F.cross_entropy(logits.view(-1, logits.size()[-1]),
target.view(-1),
reduction='none').view_as(target).mean()
loss.backward()
return loss, identity.weight.grad
def tensor_sharded_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
set_random_seed(seed)
identity = IdentityLayer((batch_size, seq_length, vocab_size), scale=logits_scale).cuda()
logits = identity()
logits_parallel = tensor_parallel.scatter_to_tensor_model_parallel_region(logits)
target = torch.cuda.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size)
logits_parallel_ = logits_parallel.clone().detach()
loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
loss.backward()
# check for mutation
assert torch.equal(logits_parallel_, logits_parallel)
return loss, identity.weight.grad
def test_cross_entropy(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing cross entropy with model parallel size {} ...'.
format(tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
batch_size = 13
seq_length = 17
vocab_size_per_partition = 11
logits_scale = 1000.0
vocab_size = vocab_size_per_partition * tensor_model_parallel_size
seed = 1234
loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed)
loss_mpu, grad_mpu = tensor_sharded_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed)
error = loss_torch.sub_(loss_mpu).abs().max()
print(' max error in loss on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = grad_torch.sub_(grad_mpu).abs().max()
print(' max error in grad on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == '__main__':
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test cross entropy')
test_cross_entropy(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import operator
import torch
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import data as data_utils
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
def test_broadcast_data(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing broadcast_data with model parallel size {} ...'.
format(tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
torch.manual_seed(1234 + parallel_state.get_data_parallel_rank())
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
key_size_t = {
'key1': [7, 11],
'key2': [8, 2, 1],
'key3': [13],
'key4': [5, 1, 2],
'key5': [5, 12],
}
keys = list(key_size_t.keys())
data = {}
data_t = {}
for key in key_size_t:
data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
data_t[key] = data[key].clone()
data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
data_t['keyX'] = data['keyX'].clone()
if parallel_state.get_tensor_model_parallel_rank() != 0:
data = None
data_utils._check_data_types(keys, data_t, torch.int64)
key_size, key_numel, \
total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
for key in keys:
assert key_size[key] == key_size_t[key]
total_numel_t = 0
for key in keys:
target_size = functools.reduce(operator.mul, key_size_t[key], 1)
assert key_numel[key] == target_size
total_numel_t += target_size
assert total_numel == total_numel_t
data_b = data_utils.broadcast_data(keys, data, torch.int64)
for key in keys:
tensor = data_t[key].cuda()
assert data_b[key].sub(tensor).abs().max() == 0
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == '__main__':
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test test broadcast data')
test_broadcast_data(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
@@ -44,7 +44,13 @@ HIDDEN_SIZE = 16


def Dataset(num_samples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    return [
        (
            torch.randn(HIDDEN_SIZE, HIDDEN_SIZE),
            torch.randn(HIDDEN_SIZE // 2, HIDDEN_SIZE // 2),
        )
        for _ in range(num_samples)
    ]


# Run forward & backward with dynamic batch size.
@@ -66,9 +72,13 @@ def run_interleaved_with_dynamic_batch_size(
    parallel_state.initialize_model_parallel(
        1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size
    )
    pipeline_model_parallel_size = (
        parallel_state.get_pipeline_model_parallel_world_size()
    )

    print_separator(
        f"BatchSamplerCls: {BatchSamplerCls.__name__}, forward_only: {forward_only}"
    )

    model = build_model(
        model_provider_func,
@@ -158,7 +168,10 @@ if __name__ == "__main__":
        args.micro_batch_size,
        1,  # args.data_parallel_size,
    )
    for BatchSamplerCls in (
        MegatronPretrainingSampler,
        MegatronPretrainingRandomSampler,
    ):
        for forward_only in (False, True):
            n_tests += 1
            pipeline_model_parallel_size = world_size
from functools import partial
from typing import List
import time

import torch

from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import model_parallel_cuda_manual_seed
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import (
    average_losses_across_data_parallel_group,
)
from apex.transformer.pipeline_parallel.utils import get_ltor_masks_and_position_ids
from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.common import (
    _get_params_for_weight_decay_optimization,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
    forward_backward_pipelining_without_interleaving,
)
from apex.transformer.testing.standalone_gpt import gpt_model_provider
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
@@ -23,41 +28,45 @@ inds = None
data_idx = 0
N_VOCAB = 128


def download_fancy_data():
    # import requests
    # response = requests.get('https://internet.com/book.txt')
    # text = ' '.join(response.text.split())
    text = """
  An original sentence not subject to any license restrictions, copyright, or royalty payments. Nothing to see here. Commercial or non-commercial use. Research or non-research purposes. The quick brown fox jumps over the lazy dog. Lorem ipsum.
  """
    text = text * 1024
    encoded = text.encode("ascii", "replace")
    ints = [int(encoded[i]) for i in range(len(encoded))]
    return torch.tensor(ints)


# build a batch given sequence_len and batch size
def generate_fancy_data_labels(sequence_len, batch_size):
    global data_idx
    global inds
    global MANUAL_SEED
    temps = list()
    for i in range(batch_size):
        if inds is None or data_idx >= len(inds):
            # hack as use of RNG will fall out of sync due to pipelines being different
            model_parallel_cuda_manual_seed(MANUAL_SEED)
            inds = torch.randperm(effective_length, device="cuda")
            MANUAL_SEED += 1
            data_idx = 0
        data_idx_ = data_idx
        offset = inds[data_idx_]
        data_idx += 1
        curr = fancy_data[offset : offset + sequence_len + 1].clone().detach()
        temps.append(curr)
    temp = torch.stack(temps, dim=0).cuda()
    return temp


easy_data = None


def get_batch(int_tensors: List[torch.Tensor]):
    data = int_tensors[0]
    # Unpack.
@@ -84,7 +93,7 @@ def loss_func(loss_mask, output_tensor):
    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])
    return loss, {"lm loss": averaged_loss[0]}


# Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L86
@@ -103,24 +112,31 @@ def train(model, optim, pipeline_model_parallel_size):
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    runtime = 0
    # training loop
    for i in range(3):
        since = time.time()
        if torch.distributed.get_rank() == 0:
            print("begin iter", i)
        batch = [
            generate_fancy_data_labels(args.seq_length, args.global_batch_size)
            for _ in range(pipeline_model_parallel_size)
        ]
        if torch.distributed.get_rank() == 0:
            print("finished making batch...")
        optim.zero_grad()
        fwd_bwd_func(
            fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape
        )
        if torch.distributed.get_rank() == 0:
            print("finished forward step")
        optim.step()
        if torch.distributed.get_rank() == 0:
            print("finished iter", i)
        runtime += time.time() - since
    return runtime / 3.0


if __name__ == "__main__":
    global fancy_data
    global effective_length
@@ -134,7 +150,6 @@ if __name__ == '__main__':
    initialize_distributed()
    world_size = torch.distributed.get_world_size()
    failure = None
    args.padded_vocab_size = 128
    batch_size = args.global_batch_size
@@ -148,16 +163,19 @@ if __name__ == '__main__':
    )
    world_size = torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=args.tensor_model_parallel_size,
        pipeline_model_parallel_size_=args.pipeline_model_parallel_size,
    )
    pipeline_model_parallel_size = (
        parallel_state.get_pipeline_model_parallel_world_size()
    )
    model_parallel_cuda_manual_seed(0)
    model = build_model(
        gpt_model_provider,
        wrap_with_ddp=True,
        virtual_pipeline_model_parallel_size=None,
        cpu_offload=args.cpu_offload,
    )
    assert isinstance(model, list), model
    _param_groups = _get_params_for_weight_decay_optimization(model)
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex.transformer import parallel_state
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
def test_initialize_model_parallel(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing initialize_model_parallel with size {} ...'.format(
tensor_model_parallel_size))
tensor_model_parallel_size_ = min(
tensor_model_parallel_size,
torch.distributed.get_world_size(),
)
assert not parallel_state.model_parallel_is_initialized()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_)
assert parallel_state.model_parallel_is_initialized()
# Checks.
def check(group, world_size, rank):
assert world_size == torch.distributed.get_world_size(group=group)
assert rank == torch.distributed.get_rank(group=group)
# Model parallel.
world_size = tensor_model_parallel_size_
rank = torch.distributed.get_rank() % tensor_model_parallel_size_
assert world_size == parallel_state.get_tensor_model_parallel_world_size()
assert rank == parallel_state.get_tensor_model_parallel_rank()
check(parallel_state.get_tensor_model_parallel_group(), world_size, rank)
# Data parallel.
world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
rank = torch.distributed.get_rank() // tensor_model_parallel_size
assert world_size == parallel_state.get_data_parallel_world_size()
assert rank == parallel_state.get_data_parallel_rank()
check(parallel_state.get_data_parallel_group(), world_size, rank)
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
if torch.distributed.get_rank() == 0:
print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format(
tensor_model_parallel_size_))
tensor_model_parallel_size = min(
tensor_model_parallel_size_,
torch.distributed.get_world_size(),
)
assert not parallel_state.model_parallel_is_initialized()
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
assert parallel_state.model_parallel_is_initialized()
# Checks
src_rank = torch.distributed.get_rank() - parallel_state.get_tensor_model_parallel_rank()
assert parallel_state.get_tensor_model_parallel_src_rank() == src_rank
split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
assert split_rank is None
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
def test_pipeline_model_parallel_split_rank():
pipeline_model_parallel_split_rank_ = 1
assert not parallel_state.model_parallel_is_initialized()
parallel_state.initialize_model_parallel(pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank_)
assert parallel_state.model_parallel_is_initialized()
split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
assert split_rank is pipeline_model_parallel_split_rank_
fake_split_rank = 7
parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank)
split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
assert split_rank == fake_split_rank
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
if __name__ == '__main__':
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test initialize model parallel')
test_initialize_model_parallel(tensor_model_parallel_size)
print_separator('test model parallel source rank')
test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
print_separator('test pipeline model parallel split rank')
test_pipeline_model_parallel_split_rank()
tensor_model_parallel_size *= 2
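For readers following the rank arithmetic asserted in test_initialize_model_parallel above, here is a small worked example. The world size and tensor-model-parallel size below are arbitrary choices that only mirror the formulas used in the test, not any particular CI configuration.

# Worked example of the rank bookkeeping checked above (hypothetical sizes).
world_size = 8
tensor_model_parallel_size = 2

for global_rank in range(world_size):
    # Tensor-model-parallel rank: position inside a contiguous TP group.
    tp_rank = global_rank % tensor_model_parallel_size
    # Data-parallel rank: which TP group this rank belongs to.
    dp_rank = global_rank // tensor_model_parallel_size
    dp_world_size = world_size // tensor_model_parallel_size
    assert 0 <= tp_rank < tensor_model_parallel_size
    assert 0 <= dp_rank < dp_world_size

# e.g. global rank 5 -> tensor-parallel rank 1, data-parallel rank 2,
# with a data-parallel world size of 4.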
import torch
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import mappings
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import initialize_distributed
global_vars.set_global_variables()
def test__reduce(args, tensor_model_parallel_size):
print("Testing reduction size =", tensor_model_parallel_size)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
assert torch.equal(
mappings._reduce(torch.full((10, 10, 10, 10), (50))),
torch.full((10, 10, 10, 10), 50 * tensor_model_parallel_size),
)
parallel_state.destroy_model_parallel()
print("Passed!")
def test__split(args, tensor_model_parallel_size):
print("Testing splitting size =", tensor_model_parallel_size)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
listy = []
for i in range(tensor_model_parallel_size):
listy.append(torch.randn(10, 1))
x = torch.cat(tuple(listy), 1)
out = mappings._split(x)
assert torch.equal(out, listy[parallel_state.get_tensor_model_parallel_rank()])
parallel_state.destroy_model_parallel()
print("Passed!")
def test__gather(args, tensor_model_parallel_size):
print("Testing gathering size =", tensor_model_parallel_size)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
assert torch.equal(
mappings._gather(torch.tensor([parallel_state.get_tensor_model_parallel_rank()])),
torch.tensor(list(range(tensor_model_parallel_size))),
)
parallel_state.destroy_model_parallel()
print("Passed!")
if __name__ == "__main__":
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
args = global_vars.get_args()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
test__reduce(args, tensor_model_parallel_size)
test__split(args, tensor_model_parallel_size)
test__gather(args, tensor_model_parallel_size)
tensor_model_parallel_size *= 2
print(">> passed the test :-)")
from functools import partial
import logging
from typing import List
import torch
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization
from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import forward_backward_pipelining_without_interleaving
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group
from apex.transformer.pipeline_parallel.utils import get_ltor_masks_and_position_ids
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import update_num_microbatches
from apex.transformer.tensor_parallel import model_parallel_cuda_manual_seed
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.standalone_gpt import gpt_model_provider
from apex.transformer.log_util import get_transformer_logger, set_logging_level
set_logging_level(logging.NOTSET)
_logger = get_transformer_logger("megatron_gpt_pipeline_test")
global_vars.set_global_variables()
N_VOCAB = 8192
def generate_batch(batch_size, sequence_length):
size = batch_size, sequence_length + 1
int_tensor = torch.randint(low=0, high=N_VOCAB, size=size, dtype=torch.long).cuda()
return int_tensor,
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L44
def get_batch(int_tensors: List[torch.Tensor]):
data = int_tensors[0]
# Unpack.
tokens_ = data.long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
N_VOCAB, # tokenizer.eod,
False, # args.reset_position_ids,
False, # args.reset_attention_mask,
False, # args.eod_mask_loss,
)
return tokens, labels, loss_mask, attention_mask, position_ids
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L75
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L86
# TODO (mkozuki): Currently I'm seeing no attribute `word_embeddings` which looks weird.
def forward_step(batch, model):
"""Forward step."""
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(batch)
output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
return output_tensor, partial(loss_func, loss_mask)
def run_gpt(pipeline_model_parallel_size, virtual_pipeline_model_parallel_size=None, forward_only=False):
parallel_state.initialize_model_parallel(1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size)
model_parallel_cuda_manual_seed(42)
model = build_model(
gpt_model_provider, True,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size)
_logger.debug("building model")
assert isinstance(model, list)
assert len(model) == (1 or virtual_pipeline_model_parallel_size)
_param_groups = _get_params_for_weight_decay_optimization(model)
torch.optim.Adam(_param_groups)
if parallel_state.is_pipeline_last_stage():
_logger.debug("checking `word_embeddings` existence")
for m in model:
assert hasattr(m, "word_embeddings")
args = global_vars.get_args()
if virtual_pipeline_model_parallel_size is None:
batch = generate_batch(args.global_batch_size, args.seq_length)
else:
batch = [generate_batch(args.global_batch_size, args.seq_length) for _ in range(virtual_pipeline_model_parallel_size)]
_logger.debug("preparing batch")
if virtual_pipeline_model_parallel_size is None:
fwd_bwd_func = forward_backward_pipelining_without_interleaving
else:
fwd_bwd_func = _forward_backward_pipelining_with_interleaving
_logger.debug(f"selecting forward_backward func: {fwd_bwd_func}")
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
_logger.debug(f"`tensor_shape`: {tensor_shape}")
fwd_bwd_func(forward_step, batch, model, forward_only=forward_only, tensor_shape=tensor_shape)
_logger.debug(TEST_SUCCESS_MESSAGE)
if __name__ == "__main__":
initialize_distributed()
args = global_vars.get_args()
args.padded_vocab_size = N_VOCAB
setup_microbatch_calculator(
args.rank,
args.rampup_batch_size,
args.global_batch_size,
args.micro_batch_size,
1, # args.data_parallel_size,
)
update_num_microbatches(0, True)
print_separator("run GPT model")
try:
run_gpt(torch.distributed.get_world_size())
# TODO(mkozuki): handle exception correctly, but for now, lazily commenting out as
# this won't get kicked by CI
except Exception as e:
_logger.debug(str(e))
pass
finally:
parallel_state.destroy_model_parallel()
import itertools
from typing import Optional
import warnings
import torch
from torch.cuda.amp import GradScaler
from apex._autocast_utils import _get_autocast_dtypes
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import get_forward_backward_func
from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization
from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import forward_backward_pipelining_without_interleaving
from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import update_num_microbatches
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import model_provider_func
from apex.transformer.testing.commons import fwd_step_func
from apex.transformer.log_util import get_transformer_logger, set_logging_level
# set_logging_level("INFO")
_logger = get_transformer_logger("pipeline_parallel_test")
global_vars.set_global_variables()
batch_size, micro_batch_size = None, None
hidden_size = 16
fwd_bwd_functions = {
"no_pipelining": forward_backward_no_pipelining,
"no_interleaving": forward_backward_pipelining_without_interleaving,
"interleaving": _forward_backward_pipelining_with_interleaving,
}
# Run forward & backward for one minibatch.
def forward_backward_func_template(
args,
name: str,
forward_backward_func,
pipeline_model_parallel_size: int,
forward_only: bool,
dtype: torch.dtype,
grad_scaler: Optional[GradScaler],
deallocate_pipeline_outputs: bool,
data_parallel_size: int,
) -> None:
print_separator(
f"{name}, {dtype}, use grad_scaler: {grad_scaler is not None}, "
f"deallocate_pipeline_outputs: {deallocate_pipeline_outputs}, "
f"pipeline parallel size: {pipeline_model_parallel_size}, "
f"data parallel size: {data_parallel_size}"
)
virtual_pipeline_model_parallel_size = 2 if name == "interleaving" else None
if name == "no_pipelining":
# note (mkozuki): `forward_backward_no_pipelining` is **NOT** compatible with
# pipeline_model_parallel_size>1. So use pipeline_model_parallel_size as
# tensor_model_parallel_size and set pipeline_model_parallel_size to 1.
parallel_state.initialize_model_parallel(1, 1, None)
_reconfigure_microbatch_calculator(
args.rank,
args.rampup_batch_size,
args.global_batch_size,
args.micro_batch_size,
parallel_state.get_data_parallel_world_size(),
)
else:
# NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is necessary to enable interleaving scheduling
# In megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and
# used ubiquitously but this test uses custom model so it's safe to abuse.
parallel_state.initialize_model_parallel(
data_parallel_size, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size)
_reconfigure_microbatch_calculator(
args.rank,
args.rampup_batch_size,
args.global_batch_size,
args.micro_batch_size,
parallel_state.get_data_parallel_world_size(),
)
if virtual_pipeline_model_parallel_size is not None:
# Check the experimental warning message
get_forward_backward_func(virtual_pipeline_model_parallel_size, pipeline_model_parallel_size)
pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
model = build_model(
model_provider_func,
wrap_with_ddp=True,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
hidden_size=hidden_size,
)
assert isinstance(model, list)
assert len(model) == (1 if virtual_pipeline_model_parallel_size is None else virtual_pipeline_model_parallel_size)
_param_groups = _get_params_for_weight_decay_optimization(model)
torch.optim.Adam(_param_groups, lr=1e-4)
tensor_shape = [batch_size // parallel_state.get_data_parallel_world_size(), hidden_size, hidden_size]
batch = (torch.randn(tensor_shape).cuda(),)
tensor_shape[0] = micro_batch_size
update_num_microbatches(0)
forward_backward_func(
fwd_step_func, batch, model, forward_only=forward_only, tensor_shape=tensor_shape,
dtype=dtype, grad_scaler=grad_scaler, deallocate_pipeline_outputs=deallocate_pipeline_outputs,
)
if not forward_only:
for m in model:
for p in m.parameters():
if p.grad is None:
raise RuntimeError("grad not found")
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == "__main__":
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
n_tests = 0
failures = []
initialize_distributed()
world_size = torch.distributed.get_world_size()
args = global_vars.get_args()
batch_size = args.global_batch_size
micro_batch_size = args.micro_batch_size
dtypes = [torch.float32] + _get_autocast_dtypes()
for forward_only, name, dtype, deallocate_pipeline_outputs in itertools.product(
(True, False),
fwd_bwd_functions.keys(),
dtypes,
(True, False),
):
forward_backward_func = fwd_bwd_functions[name]
if name == "interleaving" and torch.cuda.device_count() <= 2:
warnings.warn(
f"There's only {torch.cuda.device_count()} gpus therefore skipping {name} "
"while interleaved scheduled pipeline parallel requires >2 gpus."
)
continue
grad_scaler = torch.cuda.amp.GradScaler(init_scale=4.0) if dtype == torch.half else None
n_tests += 1
data_parallel_size = 2 if world_size >= 8 and world_size % 2 == 0 else 1
pipeline_model_parallel_size = world_size if world_size < 8 else world_size // 2
try:
forward_backward_func_template(
args,
name,
forward_backward_func,
pipeline_model_parallel_size,
forward_only,
dtype=dtype,
grad_scaler=grad_scaler,
deallocate_pipeline_outputs=deallocate_pipeline_outputs,
data_parallel_size=data_parallel_size,
)
except Exception as e:
failures.append(
f"\t# {name} failed with pipeline size: {pipeline_model_parallel_size} "
f"and forward_only: {forward_only}\n"
f"pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()}, "
f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}\n"
f"{str(e)}"
)
print(failures[-1])
finally:
parallel_state.destroy_model_parallel()
print_separator("TEST RESULT")
if failures:
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print("\n".join(failures))
msg = f"{len(failures)} / {n_tests} cases failed"
raise RuntimeError(msg)
else:
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print("### PASS!")
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
def test_set_cuda_rng_state(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing set_rng_state with size {} ...'.
format(tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
size = 123
seed = 1234
torch.cuda.manual_seed(seed)
tensor = torch.cuda.FloatTensor(size)
# Get the state
rng_state = torch.cuda.get_rng_state()
rng_state_copy = rng_state.clone()
# Do some stuff.
for _ in range(5):
torch.randn(size, out=tensor)
result_1 = tensor.clone()
assert rng_state.sub(rng_state_copy).max() == 0
assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0
# State should be different.
new_rng_state = torch.cuda.get_rng_state()
max_diff = new_rng_state.sub(rng_state).max()
print(' max diff in rng state (should be non-zero) on global rank {}: {}'.
format(torch.distributed.get_rank(), max_diff))
assert max_diff > 0
# Reset the rng state and do the same stuff.
tensor_parallel.random._set_cuda_rng_state(rng_state)
for _ in range(5):
torch.randn(size, out=tensor)
tensor_parallel.random._set_cuda_rng_state(rng_state)
for _ in range(5):
torch.randn(size, out=tensor)
result_2 = tensor.clone()
# Results should be the same
error = result_2.sub(result_1).abs().max()
print(' max error in generated tensors (should be zero) on '
'global rank {}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Input state should have remained intact.
error = rng_state.sub(rng_state_copy).max()
print(' max error in rng state (should be zero) on global rank {}: {}'.
format(torch.distributed.get_rank(), error))
assert error == 0
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
def test_cuda_rng_tracker(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing cuda rng tracker with size {} ...'.
format(tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed_1 = 1234
seed_2 = 4321
size = [12, 21]
tensor = torch.cuda.FloatTensor(size)
# Set to seed_1 and generate two tensors.
torch.cuda.manual_seed(seed_1)
torch.randn(size, out=tensor)
target_11 = tensor.clone()
torch.randn(size, out=tensor)
target_12 = tensor.clone()
# Set to seed_2 and generate two tensors.
torch.cuda.manual_seed(seed_2)
torch.randn(size, out=tensor)
target_21 = tensor.clone()
torch.randn(size, out=tensor)
target_22 = tensor.clone()
# Now if we interleave seed_1 and seed_2,
# we should still get the same tensors
torch.cuda.manual_seed(seed_1)
tensor_parallel.random.get_cuda_rng_tracker().add('test', seed_2)
torch.randn(size, out=tensor)
result_11 = tensor.clone()
with tensor_parallel.random.get_cuda_rng_tracker().fork('test'):
torch.randn(size, out=tensor)
result_21 = tensor.clone()
torch.randn(size, out=tensor)
result_12 = tensor.clone()
with tensor_parallel.random.get_cuda_rng_tracker().fork('test'):
torch.randn(size, out=tensor)
result_22 = tensor.clone()
diff = result_11.sub(result_21).abs().max()
diff = min(diff, result_12.sub(result_22).abs().max())
print(' max diff in generated tensors (should be non-zero) on '
'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
assert diff > 1.0e-6
error = max(result_11.sub(target_11).abs().max(),
result_12.sub(target_12).abs().max())
error = max(error, result_21.sub(target_21).abs().max())
error = max(error, result_22.sub(target_22).abs().max())
print(' max error in generated tensors (should be zero) on '
'global rank {}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset the tracker
tensor_parallel.random.get_cuda_rng_tracker().reset()
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print(
'> testing model parallel cuda manual seed with size {} ...'.format(
tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
tensor_parallel.random.model_parallel_cuda_manual_seed(12345)
assert torch.cuda.initial_seed() == 12345
with tensor_parallel.random.get_cuda_rng_tracker().fork():
assert (
torch.cuda.initial_seed() ==
12345 + 2718 + parallel_state.get_tensor_model_parallel_rank()
)
# Reset the tracker
tensor_parallel.random.get_cuda_rng_tracker().reset()
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == '__main__':
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test set rng state')
test_set_cuda_rng_state(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test cuda rng tracker')
test_cuda_rng_tracker(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test model parallel cuda manual seed')
test_model_parallel_cuda_manual_seed(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
import torch
from apex.transformer.tensor_parallel import utils
def test_divide():
assert utils.divide(8, 4) == 2
def test_split_tensor_along_last_dim():
inputy = torch.randn((100, 100, 100))
splits = utils.split_tensor_along_last_dim(inputy, 10)
last_dim_shapes = torch.tensor([int(split.size()[-1]) for split in splits])
assert torch.equal(last_dim_shapes, torch.full((10,), 10))
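# For reference: utils.divide(8, 4) returns 2 (it is expected to assert divisibility
# first), and splitting the last dimension of a (100, 100, 100) tensor into 10
# partitions yields views of shape (100, 100, 10).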
if __name__ == "__main__":
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
test_divide()
test_split_tensor_along_last_dim()
print(">> passed the test :-)")
...@@ -101,7 +101,7 @@ class TestBatchSamplerBehavior(unittest.TestCase): ...@@ -101,7 +101,7 @@ class TestBatchSamplerBehavior(unittest.TestCase):
samples2.append(batch) samples2.append(batch)
if i == 4 - 1: if i == 4 - 1:
break break
torch.testing.assert_allclose(torch.cat(samples), torch.cat(samples2)) torch.testing.assert_close(torch.cat(samples), torch.cat(samples2))
def test_split_batch(self): def test_split_batch(self):
...@@ -127,11 +127,6 @@ class TestBatchSamplerBehavior(unittest.TestCase): ...@@ -127,11 +127,6 @@ class TestBatchSamplerBehavior(unittest.TestCase):
global_batch_size = 16 global_batch_size = 16
loader = DataLoader(dataset, batch_sampler=MegatronPretrainingRandomSampler(100, 0, global_batch_size, 0, 1), num_workers=2) loader = DataLoader(dataset, batch_sampler=MegatronPretrainingRandomSampler(100, 0, global_batch_size, 0, 1), num_workers=2)
batch = next(iter(loader)) batch = next(iter(loader))
# samples = None
# for i, batch in enumerate(loader):
# # samples = batch
# if i == 0:
# break
for _micro_batch_size in (1, 2, 4, 8): for _micro_batch_size in (1, 2, 4, 8):
microbatches = list(split_batch_into_microbatch( microbatches = list(split_batch_into_microbatch(
...@@ -139,8 +134,6 @@ class TestBatchSamplerBehavior(unittest.TestCase): ...@@ -139,8 +134,6 @@ class TestBatchSamplerBehavior(unittest.TestCase):
_micro_batch_size=_micro_batch_size, _micro_batch_size=_micro_batch_size,
_global_batch_size=global_batch_size, _global_batch_size=global_batch_size,
)) ))
# print(batch)
# print(microbatches)
self.assertEqual(len(microbatches), global_batch_size // _micro_batch_size) self.assertEqual(len(microbatches), global_batch_size // _micro_batch_size)
self.assertEqual(len(microbatches[0][0]), _micro_batch_size) self.assertEqual(len(microbatches[0][0]), _micro_batch_size)
......
import logging
from typing import Tuple
import torch
import torch.nn.functional as F
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.tensor_parallel import cross_entropy
from apex.transformer.testing.commons import set_random_seed, IdentityLayer
from apex.transformer.testing.distributed_test_base import DistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
def torch_cross_entropy(
batch_size: int, seq_length: int, vocab_size: int, logits_scale: float, seed: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
set_random_seed(seed)
identity = IdentityLayer(
(batch_size, seq_length, vocab_size), scale=logits_scale
).cuda()
logits = identity()
target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size)
loss = (
F.cross_entropy(
logits.view(-1, logits.size()[-1]), target.view(-1), reduction="none"
)
.view_as(target)
.mean()
)
loss.backward()
return loss, identity.weight.grad
def tensor_sharded_cross_entropy(
batch_size, seq_length, vocab_size, logits_scale, seed
):
set_random_seed(seed)
identity = IdentityLayer(
(batch_size, seq_length, vocab_size), scale=logits_scale
).cuda()
logits = identity()
logits_parallel = tensor_parallel.scatter_to_tensor_model_parallel_region(logits)
target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size)
logits_parallel_ = logits_parallel.clone().detach()
loss = cross_entropy.vocab_parallel_cross_entropy(logits_parallel, target).mean()
loss.backward()
# check for mutation
assert torch.equal(logits_parallel_, logits_parallel)
return loss, identity.weight.grad
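# The test below compares the unsharded reference (torch_cross_entropy) against the
# vocab-sharded implementation: both draw identical logits from IdentityLayer with the
# same seed, so the parallel loss and the gradient w.r.t. the identity weight are
# expected to match the single-process values for every tensor model parallel size.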
class VocabParallelCrossEntropy(DistributedTestBase):
def test_cross_entropy(self):
batch_size, sequence_length, vocab_size_per_partition = 13, 17, 11
logits_scale = 1000.0
seed = 1234
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
vocab_size = vocab_size_per_partition * tensor_model_parallel_world_size
loss_torch, grad_torch = torch_cross_entropy(
batch_size, sequence_length, vocab_size, logits_scale, seed
)
(
loss_tensor_parallel,
grad_tensor_parallel,
) = tensor_sharded_cross_entropy(
batch_size, sequence_length, vocab_size, logits_scale, seed
)
torch.testing.assert_close(loss_torch, loss_tensor_parallel)
torch.testing.assert_close(grad_torch, grad_tensor_parallel)
parallel_state.destroy_model_parallel()
if __name__ == "__main__":
common_utils.run_tests()
import logging
import torch.testing
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import data as data_utils
from apex.transformer.testing.distributed_test_base import DistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
class BroadcastDataTest(DistributedTestBase):
def test_broadcast_data(self):
        # Use half of the available ranks for tensor model parallelism when more than one GPU is present.
        tensor_model_parallel_world_size: int = self.world_size // (
            1 + int(self.world_size > 1)
        )
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size
)
target_key_size = {
"key1": [7, 11],
"key2": [8, 2, 1],
"key3": [13],
"key4": [5, 1, 2],
"key5": [5, 12],
}
keys = [k for k in target_key_size]
data = {}
data_t = {}
with torch.no_grad():
for key in target_key_size:
data[key] = torch.randint(0, 1000, size=target_key_size[key])
data_t[key] = data[key].clone()
# "key_x" is supposed to be ignored.
data["key_x"] = torch.rand(5)
data_t["key_x"] = data["key_x"].clone()
if parallel_state.get_tensor_model_parallel_rank() != 0:
data = None
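        # broadcast_data is expected to broadcast from tensor model parallel rank 0 to
        # the other ranks of the group, so non-source ranks pass None and every rank
        # compares the result against the locally kept copy in data_t.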
data_utils._check_data_types(keys, data_t, torch.int64)
key_size, _, _ = data_utils._build_key_size_numel_dictionaries(keys, data)
for key in keys:
self.assertEqual(target_key_size[key], key_size[key])
broadcasted_data = data_utils.broadcast_data(keys, data, torch.int64)
for key in keys:
torch.testing.assert_close(broadcasted_data[key], data_t[key].cuda())
parallel_state.destroy_model_parallel()
if __name__ == "__main__":
common_utils.run_tests()
...@@ -15,12 +15,20 @@ def attention_mask_func(attention_scores, attention_mask): ...@@ -15,12 +15,20 @@ def attention_mask_func(attention_scores, attention_mask):
return attention_scores.masked_fill(attention_mask, -10000.0) return attention_scores.masked_fill(attention_mask, -10000.0)
autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,) autocast_dtypes = (
(torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
)
class TestFusedScaleMaskSoftmax(unittest.TestCase): class TestFusedScaleMaskSoftmax(unittest.TestCase):
def _setup_fused_softmax(
def _setup_fused_softmax(self, input_in_fp16, input_in_bf16, scale=None, softmax_in_fp32=False, attn_mask_type=AttnMaskType.padding): self,
input_in_fp16,
input_in_bf16,
scale=None,
softmax_in_fp32=False,
attn_mask_type=AttnMaskType.padding,
):
fused_fn = FusedScaleMaskSoftmax( fused_fn = FusedScaleMaskSoftmax(
input_in_fp16=input_in_fp16, input_in_fp16=input_in_fp16,
input_in_bf16=input_in_bf16, input_in_bf16=input_in_bf16,
...@@ -47,26 +55,40 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase): ...@@ -47,26 +55,40 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
mask.shape = [4, 1, 24, 24] mask.shape = [4, 1, 24, 24]
""" """
for (dtype, scale, softmax_in_fp32) in itertools.product( for (dtype, scale, softmax_in_fp32) in itertools.product(
(torch.half, torch.bfloat16), (torch.half, torch.bfloat16), (None, 2.0), (False, True),
(None, 2.0),
(False, True),
): ):
with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"): with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"):
input_in_fp16 = dtype == torch.half input_in_fp16 = dtype == torch.half
input_in_bf16 = dtype == torch.bfloat16 input_in_bf16 = dtype == torch.bfloat16
if not (scale is None or softmax_in_fp32): if not (scale is None or softmax_in_fp32):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
self._setup_fused_softmax(input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.padding) self._setup_fused_softmax(
input_in_fp16,
input_in_bf16,
scale,
softmax_in_fp32,
AttnMaskType.padding,
)
return return
fused_fn, torch_fn = self._setup_fused_softmax(input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.padding) fused_fn, torch_fn = self._setup_fused_softmax(
input_in_fp16,
attention_scores_0 = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype).requires_grad_(True) input_in_bf16,
scale,
softmax_in_fp32,
AttnMaskType.padding,
)
attention_scores_0 = (
torch.randn((4, 12, 24, 24))
.to(device="cuda", dtype=dtype)
.requires_grad_(True)
)
with torch.no_grad(): with torch.no_grad():
attention_scores_1 = attention_scores_0.clone().requires_grad_(True) attention_scores_1 = attention_scores_0.clone().requires_grad_(True)
mask = torch.randint(0, 2, (4, 1, 24, 24), device="cuda").bool() mask = torch.randint(0, 2, (4, 1, 24, 24), device="cuda").bool()
expected = fused_fn(attention_scores_0, mask) expected = fused_fn(attention_scores_0, mask)
actual = torch_fn(attention_scores_1, mask) actual = torch_fn(attention_scores_1, mask)
torch.testing.assert_allclose(actual, expected) torch.testing.assert_close(actual, expected)
g0 = torch.rand_like(actual) g0 = torch.rand_like(actual)
with torch.no_grad(): with torch.no_grad():
...@@ -80,18 +102,23 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase): ...@@ -80,18 +102,23 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
input_in_fp16 = dtype == torch.half input_in_fp16 = dtype == torch.half
input_in_bf16 = dtype == torch.bfloat16 input_in_bf16 = dtype == torch.bfloat16
fused_fn, torch_fn = self._setup_fused_softmax( fused_fn, torch_fn = self._setup_fused_softmax(
input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding) input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding
)
attention_scores_0 = torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True) attention_scores_0 = (
torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
)
with torch.no_grad(): with torch.no_grad():
attention_scores_1 = attention_scores_0.clone().to(dtype).requires_grad_(True) attention_scores_1 = (
attention_scores_0.clone().to(dtype).requires_grad_(True)
)
mask = torch.randint(0, 2, (4, 1, 24, 24)).bool().cuda() mask = torch.randint(0, 2, (4, 1, 24, 24)).bool().cuda()
expected = torch_fn(attention_scores_1, mask) expected = torch_fn(attention_scores_1, mask)
with torch.cuda.amp.autocast(dtype=dtype): with torch.cuda.amp.autocast(dtype=dtype):
actual = fused_fn(attention_scores_0, mask) actual = fused_fn(attention_scores_0, mask)
self.assertEqual(actual.dtype, dtype) self.assertEqual(actual.dtype, dtype)
torch.testing.assert_allclose(actual, expected) torch.testing.assert_close(actual, expected)
g0 = torch.rand_like(actual) g0 = torch.rand_like(actual)
with torch.no_grad(): with torch.no_grad():
...@@ -108,9 +135,7 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase): ...@@ -108,9 +135,7 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
upper elements are True and lower elements and diagonal are False. upper elements are True and lower elements and diagonal are False.
""" """
for (dtype, scale, softmax_in_fp32) in itertools.product( for (dtype, scale, softmax_in_fp32) in itertools.product(
(torch.half, torch.bfloat16), (torch.half, torch.bfloat16), (None, 2.0), (False, True),
(None, 2.0),
(False, True),
): ):
with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"): with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"):
input_in_fp16 = dtype == torch.half input_in_fp16 = dtype == torch.half
...@@ -118,21 +143,37 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase): ...@@ -118,21 +143,37 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
if not (scale is None or softmax_in_fp32): if not (scale is None or softmax_in_fp32):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
self._setup_fused_softmax( self._setup_fused_softmax(
input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.causal) input_in_fp16,
input_in_bf16,
scale,
softmax_in_fp32,
AttnMaskType.causal,
)
return return
fused_fn, torch_fn = self._setup_fused_softmax( fused_fn, torch_fn = self._setup_fused_softmax(
input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.causal) input_in_fp16,
input_in_bf16,
attn_weights_0 = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype).requires_grad_(True) scale,
softmax_in_fp32,
AttnMaskType.causal,
)
attn_weights_0 = (
torch.randn((4, 12, 24, 24))
.to(device="cuda", dtype=dtype)
.requires_grad_(True)
)
with torch.no_grad(): with torch.no_grad():
attn_weights_1 = attn_weights_0.clone().requires_grad_(True) attn_weights_1 = attn_weights_0.clone().requires_grad_(True)
total_mask = (~( total_mask = (
torch.tril(torch.randn((24, 24), device="cuda")).bool() ~(torch.tril(torch.randn((24, 24), device="cuda")).bool())
).unsqueeze(0).unsqueeze(0)) .unsqueeze(0)
.unsqueeze(0)
)
total_mask = total_mask.repeat((4, 1, 1, 1)) total_mask = total_mask.repeat((4, 1, 1, 1))
expected = fused_fn(attn_weights_0, total_mask) expected = fused_fn(attn_weights_0, total_mask)
actual = torch_fn(attn_weights_1, total_mask) actual = torch_fn(attn_weights_1, total_mask)
torch.testing.assert_allclose(actual, expected) torch.testing.assert_close(actual, expected)
g0 = torch.randn_like(actual) g0 = torch.randn_like(actual)
with torch.no_grad(): with torch.no_grad():
...@@ -146,20 +187,27 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase): ...@@ -146,20 +187,27 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
input_in_fp16 = dtype == torch.half input_in_fp16 = dtype == torch.half
input_in_bf16 = dtype == torch.bfloat16 input_in_bf16 = dtype == torch.bfloat16
fused_fn, torch_fn = self._setup_fused_softmax( fused_fn, torch_fn = self._setup_fused_softmax(
input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.causal) input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.causal
)
attn_weights_0 = torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True) attn_weights_0 = (
torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
)
with torch.no_grad(): with torch.no_grad():
attn_weights_1 = attn_weights_0.clone().to(dtype).requires_grad_(True) attn_weights_1 = (
total_mask = (~( attn_weights_0.clone().to(dtype).requires_grad_(True)
torch.tril(torch.randn((24, 24), device="cuda")).bool() )
).unsqueeze(0).unsqueeze(0)) total_mask = (
~(torch.tril(torch.randn((24, 24), device="cuda")).bool())
.unsqueeze(0)
.unsqueeze(0)
)
with torch.cuda.amp.autocast(dtype=dtype): with torch.cuda.amp.autocast(dtype=dtype):
actual = fused_fn(attn_weights_0, total_mask) actual = fused_fn(attn_weights_0, total_mask)
self.assertEqual(actual.dtype, dtype) self.assertEqual(actual.dtype, dtype)
expected = torch_fn(attn_weights_1, total_mask) expected = torch_fn(attn_weights_1, total_mask)
torch.testing.assert_allclose(actual, expected) torch.testing.assert_close(actual, expected)
g0 = torch.randn_like(actual) g0 = torch.randn_like(actual)
with torch.no_grad(): with torch.no_grad():
......
import logging
import torch
import torch.nn as nn
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import layers
from apex.transformer.testing.commons import set_random_seed
from apex.transformer.testing.distributed_test_base import DistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
# N.B. (mkozuki): Disable TF32 matrix multiply.
# Matrices used in this test are so small that TF32 matmul
# can be less precise so that `self.assertEqual` raises.
torch.backends.cuda.matmul.allow_tf32 = False
class TensorParallelLayerTest(DistributedTestBase):
BATCH_SIZE: int = 17
SEQUENCE_LENGTH: int = 23
VOCAB_SIZE: int = 48
HIDDEN_SIZE: int = 16
INPUT_SIZE_COEFF: int = 13
OUTPUT_SIZE_COEFF: int = 17
SEED: int = 123
def test_parallel_embedding(self) -> None:
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
set_random_seed(TensorParallelLayerTest.SEED + 1)
input_tensor = torch.randint(
0,
TensorParallelLayerTest.VOCAB_SIZE,
(
TensorParallelLayerTest.BATCH_SIZE,
TensorParallelLayerTest.SEQUENCE_LENGTH,
),
device="cuda",
)
loss_weight = torch.randn(
(
TensorParallelLayerTest.BATCH_SIZE,
TensorParallelLayerTest.SEQUENCE_LENGTH,
TensorParallelLayerTest.HIDDEN_SIZE,
),
device="cuda",
)
set_random_seed(TensorParallelLayerTest.SEED)
embedding_torch = nn.Embedding(
TensorParallelLayerTest.VOCAB_SIZE,
TensorParallelLayerTest.HIDDEN_SIZE,
).cuda()
output_torch = embedding_torch(input_tensor)
loss_torch = torch.mul(output_torch, loss_weight).sum()
loss_torch.backward()
# N.B. (mkozuki): With affine weight initialization on GPU,
# it's super difficult to keep the consistency with nn.Embedding.
# Thus, turning on `use_cpu_initialization`.
set_random_seed(TensorParallelLayerTest.SEED)
embedding_vocab_parallel = layers.VocabParallelEmbedding(
TensorParallelLayerTest.VOCAB_SIZE,
TensorParallelLayerTest.HIDDEN_SIZE,
init_method=nn.init.normal_,
use_cpu_initialization=True,
).cuda()
output_vocab_parallel = embedding_vocab_parallel(input_tensor)
loss_vocab_parallel = torch.mul(
output_vocab_parallel, loss_weight
).sum()
loss_vocab_parallel.backward()
self.assertEqual(output_torch, output_vocab_parallel)
self.assertEqual(loss_torch, loss_vocab_parallel)
splitted_weight_torch = torch.split(
embedding_torch.weight.grad,
TensorParallelLayerTest.VOCAB_SIZE
// tensor_model_parallel_world_size,
0,
)[parallel_state.get_tensor_model_parallel_rank()]
self.assertEqual(
splitted_weight_torch, embedding_vocab_parallel.weight.grad
)
parallel_state.destroy_model_parallel()
def _affine_weight_init_test_impl(
self, init_device: str, is_column_parallel: bool
) -> None:
dim = int(not is_column_parallel)
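        # ColumnParallelLinear shards its weight along dim 0 (output features) and
        # RowParallelLinear along dim 1 (input features), hence dim is 0 for the
        # column parallel case and 1 for the row parallel case.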
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size
)
input_size: int = TensorParallelLayerTest.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
output_size: int = TensorParallelLayerTest.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size
weight_shape = (
(TensorParallelLayerTest.OUTPUT_SIZE_COEFF, input_size)
if is_column_parallel
else (output_size, TensorParallelLayerTest.INPUT_SIZE_COEFF)
)
weight = torch.empty(weight_shape)
set_random_seed(TensorParallelLayerTest.SEED)
sharding_dim_size = (
TensorParallelLayerTest.OUTPUT_SIZE_COEFF
if is_column_parallel
else TensorParallelLayerTest.INPUT_SIZE_COEFF
)
if init_device == "cpu":
layers._initialize_affine_weight_cpu(
weight,
output_size,
input_size,
sharding_dim_size,
dim,
nn.init.normal_,
params_dtype=torch.float32,
)
else:
layers._initialize_affine_weight_gpu(
weight, torch.nn.init.normal_, dim
)
# Target
set_random_seed(TensorParallelLayerTest.SEED)
if init_device == "cpu":
main_weight = torch.empty(output_size, input_size)
nn.init.normal_(main_weight)
curr_weight = torch.split(main_weight, sharding_dim_size, dim=dim)[
parallel_state.get_tensor_model_parallel_rank()
]
else:
curr_weight = torch.empty(*weight_shape)
nn.init.normal_(curr_weight)
self.assertEqual(curr_weight, weight)
parallel_state.destroy_model_parallel()
def test_affine_weight_init_column_parallel_cpu(self) -> None:
self._affine_weight_init_test_impl(init_device="cpu", is_column_parallel=True)
def test_affine_weight_init_column_parallel_gpu(self) -> None:
self._affine_weight_init_test_impl(init_device="gpu", is_column_parallel=True)
def test_affine_weight_init_row_parallel_cpu(self) -> None:
self._affine_weight_init_test_impl(init_device="cpu", is_column_parallel=False)
def test_affine_weight_init_row_parallel_gpu(self) -> None:
self._affine_weight_init_test_impl(init_device="gpu", is_column_parallel=False)
def test_row_parallel_linear(self) -> None:
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size
)
input_size: int = TensorParallelLayerTest.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
output_size: int = TensorParallelLayerTest.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size
set_random_seed(TensorParallelLayerTest.SEED)
linear_layer = layers.RowParallelLinear(
input_size,
output_size,
keep_master_weight_for_test=True,
params_dtype=torch.float32,
use_cpu_initialization=True,
).cuda()
loss_weight = torch.randn(
(TensorParallelLayerTest.BATCH_SIZE, output_size)
).cuda()
# Forward and backward
input_tensor = torch.randn(
TensorParallelLayerTest.BATCH_SIZE, input_size, requires_grad=True
).cuda()
input_tensor.retain_grad()
output, _ = linear_layer(input_tensor)
loss = torch.mul(output, loss_weight).sum()
loss.backward()
self.assertIsNotNone(input_tensor.grad)
with torch.no_grad():
dldy = loss_weight.clone()
x = input_tensor.clone()
a = linear_layer.master_weight.cuda()
dlda = torch.matmul(dldy.t(), x)
dldb = torch.matmul(
torch.ones(TensorParallelLayerTest.BATCH_SIZE, 1).cuda().t(), dldy
).view(-1)
dldx = torch.matmul(dldy, a)
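                # The values above are the closed-form gradients for y = x A^T + b:
                # dL/dA = (dL/dy)^T x, dL/db = 1^T dL/dy, dL/dx = (dL/dy) A.
                # RowParallelLinear shards A along the input dimension, so only this
                # rank's slice of dlda is compared against the local weight gradient.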
with torch.no_grad():
curr_dlda = torch.split(
dlda, TensorParallelLayerTest.INPUT_SIZE_COEFF, dim=1
)[parallel_state.get_tensor_model_parallel_rank()].clone()
self.assertEqual(linear_layer.weight.grad, curr_dlda)
self.assertEqual(input_tensor.grad, dldx)
self.assertEqual(linear_layer.bias.grad, dldb)
parallel_state.destroy_model_parallel()
def test_column_parallel_linear(self):
self._column_parallel_linear_test_impl(False, False)
def test_column_parallel_linear_no_async(self):
self._column_parallel_linear_test_impl(True, False)
def test_column_parallel_linear_gradient_accumulation_fusion(self):
self._column_parallel_linear_test_impl(False, True)
def _column_parallel_linear_test_impl(
self,
no_async_tensor_model_parallel_allreduce: bool,
gradient_accumulation_fusion: bool,
):
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
print(
f"tensor_model_parallel_world_size={tensor_model_parallel_world_size}"
)
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
if self.world_size % tensor_model_parallel_world_size:
continue
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
feature_size_coeff = TensorParallelLayerTest.INPUT_SIZE_COEFF
feature_size = feature_size_coeff * tensor_model_parallel_world_size
hidden_size = feature_size
set_random_seed(TensorParallelLayerTest.SEED)
input_tensor = torch.randn(
TensorParallelLayerTest.BATCH_SIZE,
hidden_size,
feature_size,
device="cuda",
requires_grad=True,
)
input_tensor.retain_grad()
loss_weight = torch.randn(
(TensorParallelLayerTest.BATCH_SIZE, hidden_size, feature_size,),
device="cuda",
)
linear = layers.ColumnParallelLinear(
feature_size,
feature_size,
bias=False,
keep_master_weight_for_test=True,
params_dtype=torch.float32,
use_cpu_initialization=True,
no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce,
gradient_accumulation_fusion=gradient_accumulation_fusion,
).cuda()
if gradient_accumulation_fusion:
with torch.no_grad():
linear.weight.main_grad = torch.randn_like(linear.weight)
output, _ = linear(input_tensor)
self.assertEqual(
output.shape,
(TensorParallelLayerTest.BATCH_SIZE, hidden_size, feature_size,),
)
loss = torch.mul(output, loss_weight).sum()
loss.backward()
with torch.no_grad():
dldy = loss_weight.clone()
x = input_tensor.clone()
a = linear.master_weight.cuda().clone()
dldx = torch.matmul(dldy, a)
self.assertEqual(input_tensor.grad, dldx)
# TODO (mkozuki): Cover the other cases.
if (
tensor_model_parallel_world_size == 1
and not gradient_accumulation_fusion
):
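                        # For y = x A^T without bias, dL/dA = sum over the batch of
                        # (dL/dy)^T x; ColumnParallelLinear shards A along the output
                        # dimension, so only this rank's row block is compared.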
dlda = torch.matmul(torch.transpose(dldy, 1, 2), x).sum(dim=0)
curr_dlda = torch.split(dlda, feature_size_coeff, dim=0)[
parallel_state.get_tensor_model_parallel_rank()
]
self.assertEqual(linear.weight.grad, curr_dlda)
parallel_state.destroy_model_parallel()
if __name__ == "__main__":
common_utils.run_tests()
import logging
import torch
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import mappings
from apex.transformer.testing.distributed_test_base import DistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
class MappingTest(DistributedTestBase):
def test_reduce(self):
for tensor_model_paralell_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_paralell_world_size > 0:
continue
with self.subTest(
tensor_model_paralell_world_size=tensor_model_paralell_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_paralell_world_size
)
t = torch.full((10, 10, 10, 10), 50, device=f"cuda:{self.rank}")
expected = torch.full(
(10, 10, 10, 10),
50 * tensor_model_paralell_world_size,
device=f"cuda:{self.rank}",
)
self.assertTrue(torch.equal(mappings._reduce(t), expected))
parallel_state.destroy_model_parallel()
def test_split(self):
for tensor_model_paralell_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_paralell_world_size > 0:
continue
with self.subTest(
tensor_model_paralell_world_size=tensor_model_paralell_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_paralell_world_size
)
tensors = [
torch.randn(10, 1)
for rank in range(tensor_model_paralell_world_size)
]
x = torch.cat(tensors, 1)
out = mappings._split(x)
self.assertTrue(
torch.equal(
out, tensors[parallel_state.get_tensor_model_parallel_rank()]
)
)
parallel_state.destroy_model_parallel()
def test_gather(self):
for tensor_model_paralell_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_paralell_world_size > 0:
continue
with self.subTest(
tensor_model_paralell_world_size=tensor_model_paralell_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_paralell_world_size
)
device = f"cuda:{self.rank}"
gathered = mappings._gather(
torch.tensor(
[parallel_state.get_tensor_model_parallel_rank()], device=device
)
)
expected = torch.tensor(
[rank for rank in range(tensor_model_paralell_world_size)],
device=device,
)
self.assertTrue(torch.equal(gathered, expected))
parallel_state.destroy_model_parallel()
if __name__ == "__main__":
common_utils.run_tests()
import logging
import os
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.testing.distributed_test_base import DistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
os.environ["BACKEND"] = "NCCL"
DATA_PARALLEL_WORLD_SIZE: int = 1
def calc_expected_tensor_model_parallel_rank(
rank: int, tensor_model_parallel_world_size: int,
) -> int:
return rank % tensor_model_parallel_world_size
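# Illustration (not part of the original test): with a tensor model parallel size of 2
# and 4 ranks, global ranks 0-3 map to tensor model parallel ranks 0, 1, 0, 1, and the
# source rank of each tensor model parallel group is its first global rank (0, 0, 2, 2).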
class ParallelStateTest(DistributedTestBase):
def test_initialize_model_parallel(self) -> None:
self.assertFalse(parallel_state.model_parallel_is_initialized())
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
if self.world_size % tensor_model_parallel_world_size:
continue
pipeline_model_parallel_world_size = (
self.world_size // tensor_model_parallel_world_size
)
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
)
self.assertEqual(
tensor_model_parallel_world_size,
parallel_state.get_tensor_model_parallel_world_size(),
)
                expected_tensor_model_parallel_rank = calc_expected_tensor_model_parallel_rank(
self.rank, tensor_model_parallel_world_size
)
self.assertEqual(
expected_tensor_model_parallel_rank,
parallel_state.get_tensor_model_parallel_rank(),
)
expected_tensor_model_parallel_src_rank = (
self.rank // tensor_model_parallel_world_size
) * tensor_model_parallel_world_size
self.assertEqual(
expected_tensor_model_parallel_src_rank,
parallel_state.get_tensor_model_parallel_src_rank(),
)
parallel_state.destroy_model_parallel()
self.assertFalse(parallel_state.model_parallel_is_initialized())
def test_initialize_model_parallel_with_virtual_and_split(self) -> None:
if self.world_size < 4:
self.skipTest("requires >= 4 GPUs")
self.assertFalse(parallel_state.model_parallel_is_initialized())
tensor_model_parallel_world_size = 1 + int(self.world_size > 4)
pipeline_model_parallel_world_size = (
self.world_size // tensor_model_parallel_world_size
)
virtual_pipeline_model_parallel_world_size = 2
pipeline_model_parallel_split_rank = pipeline_model_parallel_world_size // 2
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_world_size,
pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
)
self.assertEqual(
            calc_expected_tensor_model_parallel_rank(
self.rank, tensor_model_parallel_world_size
),
parallel_state.get_tensor_model_parallel_rank(),
)
self.assertEqual(
pipeline_model_parallel_world_size,
parallel_state.get_pipeline_model_parallel_world_size(),
)
self.assertEqual(
virtual_pipeline_model_parallel_world_size,
parallel_state.get_virtual_pipeline_model_parallel_world_size(),
)
expected_pipeline_rank = (
self.rank - (self.rank % tensor_model_parallel_world_size)
) % pipeline_model_parallel_world_size
self.assertEqual(
expected_pipeline_rank, parallel_state.get_pipeline_model_parallel_rank(),
)
# virtual pipeline model parallel rank is lazily set, i.e., right after the call of
# `initialize_model_parallel`, it's set to 0.
self.assertEqual(
0, parallel_state.get_virtual_pipeline_model_parallel_rank(),
)
self.assertEqual(
pipeline_model_parallel_split_rank,
parallel_state.get_pipeline_model_parallel_split_rank(),
)
fake_split_rank = 77
parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank)
self.assertEqual(
fake_split_rank, parallel_state.get_pipeline_model_parallel_split_rank()
)
parallel_state.destroy_model_parallel()
if __name__ == "__main__":
common_utils.run_tests()
import logging
import itertools
from typing import Optional
import torch
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex._autocast_utils import _get_autocast_dtypes
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import utils as pp_utils
from apex.transformer.pipeline_parallel.schedules.common import (
FwdStepFunc,
build_model,
_get_params_for_weight_decay_optimization,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import (
forward_backward_no_pipelining,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import (
_forward_backward_pipelining_with_interleaving,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
forward_backward_pipelining_without_interleaving,
)
from apex.transformer.testing.distributed_test_base import DistributedTestBase
from apex.transformer.testing import commons as testing_utils
logging.getLogger("apex").setLevel(logging.WARNING)
class PipelineParallelForwardBackwardTest(DistributedTestBase):
GLOBAL_BATCH_SIZE = 16
MICRO_BATCH_SIZE = 1
HIDDEN_SIZE = 32
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 8)
def _forward_backward_test_impl(
self,
forward_only: bool,
fwd_bwd_func: FwdStepFunc,
pipeline_model_parallel_world_size: Optional[int],
        virtual_pipeline_model_parallel_size: Optional[int],
) -> None:
for dtype, deallocate_pipeline_outputs in itertools.product(
[torch.float32] + _get_autocast_dtypes(), (True, False),
):
grad_scaler = (
torch.cuda.amp.GradScaler(init_scale=4.0)
if dtype == torch.half
else None
)
tensor_model_parallel_world_size = 1
data_parallel_size = 1 + (self.world_size >= 8 and self.world_size % 2 == 0)
            pipeline_model_parallel_world_size = (
                self.world_size
                // (tensor_model_parallel_world_size * data_parallel_size)
                if pipeline_model_parallel_world_size is None
                else pipeline_model_parallel_world_size
            )
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
                virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
)
pp_utils._reconfigure_microbatch_calculator(
rank=parallel_state.get_tensor_model_parallel_rank(),
rampup_batch_size=None,
global_batch_size=PipelineParallelForwardBackwardTest.GLOBAL_BATCH_SIZE,
micro_batch_size=PipelineParallelForwardBackwardTest.MICRO_BATCH_SIZE,
data_parallel_size=parallel_state.get_data_parallel_world_size(),
)
global_batch_shape = (
PipelineParallelForwardBackwardTest.GLOBAL_BATCH_SIZE
// parallel_state.get_data_parallel_world_size(),
PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
)
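            # Each data parallel rank receives GLOBAL_BATCH_SIZE // data_parallel_world_size
            # samples; the schedule consumes them in micro batches of MICRO_BATCH_SIZE
            # (see `tensor_shape` below).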
batch = (torch.randn(global_batch_shape).cuda(),)
model = build_model(
testing_utils.model_provider_func,
wrap_with_ddp=True,
                virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
hidden_size=PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
)
_param_groups = _get_params_for_weight_decay_optimization(model)
optimizer = torch.optim.Adam(_param_groups, lr=1e-3)
pp_utils.update_num_microbatches(0)
fwd_bwd_func(
testing_utils.fwd_step_func,
batch,
model,
forward_only=forward_only,
# `tensor_shape` is the shape of micro batch.
tensor_shape=(
PipelineParallelForwardBackwardTest.MICRO_BATCH_SIZE,
PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
),
dtype=dtype,
grad_scaler=grad_scaler,
deallocate_pipeline_output=deallocate_pipeline_outputs,
)
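            # When forward_only is True the schedule skips the backward pass and gradients
            # remain unset; otherwise every parameter should have received a gradient
            # before the optimizer step below.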
if not forward_only:
for m in model:
for p in m.parameters():
self.assertIsNotNone(p.grad)
optimizer.step()
optimizer.zero_grad(set_to_none=True)
parallel_state.destroy_model_parallel()
def test_no_pipelining(self):
self._forward_backward_test_impl(False, forward_backward_no_pipelining, 1, None)
def test_no_pipelining_inference(self):
self._forward_backward_test_impl(True, forward_backward_no_pipelining, 1, None)
def test_pipelining(self):
self._forward_backward_test_impl(
False, forward_backward_pipelining_without_interleaving, None, None
)
def test_pipelining_inference(self):
self._forward_backward_test_impl(
True, forward_backward_pipelining_without_interleaving, None, None
)
def test_pipelining_with_interleaving(self):
self._forward_backward_test_impl(
False, _forward_backward_pipelining_with_interleaving, 2, None
)
def test_pipelining_with_interleaving_inference(self):
self._forward_backward_test_impl(
True, _forward_backward_pipelining_with_interleaving, 2, None
)
if __name__ == "__main__":
common_utils.run_tests()