Unverified Commit a0ed4151 authored by Masaki Kozuki, committed by GitHub

[transformer] Format & Test Refactoring (#1325)

* try PyTorch custom TestCase class

* revert

* initial working example

* update

* data utils

* fix imports

* hardcode backend to nccl

* fix signature

* fix typo

* mapping

* set device

* init

* refactor x entropy

* remove unused import & destroy model parallel

* refactor random

* fix test

* remove migrated tests

* refactor

* init

* separate affine weight init

* init model parallel

* split more

* weight init fix part 1

* use cpu init for consistency btwn native and tensor parallel

* black

* add col parallel

* use a 3D tensor of square matrix for column parallel linear

* skip the failing cases

* migrate layers test

* pipeline parallel forward/backward

* fix typo

* fix typo

* fix

* fix pipeline world size

* black

* rm `run_pipeline_parallel_test` in favor of test_pipeline_parallel_fwd_bwd.py

* stop logging

* set log level

* black

* license and format

* fix

* skip tf32 as matrices are small

* remove potentially inappropriate license

* Apply suggestions from code review

* remove `TODO` comment

* `torch.testing.assert_allclose` -> `torch.testing.assert_close` (see the sketch below the commit header)

* remove comment-outs

* remove unused import

* minor fix
parent f10b4b89
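The bullet about `torch.testing.assert_allclose` -> `torch.testing.assert_close` refers to PyTorch deprecating the former helper. A minimal sketch of that migration, using hypothetical tensors rather than code from this commit:

import torch

# Hypothetical stand-ins for a test's actual/expected outputs.
actual = torch.tensor([1.0, 2.0, 3.0])
expected = torch.tensor([1.0, 2.0, 3.0 + 1e-7])

# Deprecated spelling:
#   torch.testing.assert_allclose(actual, expected)
# Current spelling; rtol/atol may be passed explicitly or left to the
# dtype-based defaults (for float32: rtol=1.3e-6, atol=1e-5).
torch.testing.assert_close(actual, expected, rtol=1.3e-6, atol=1e-5)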
@@ -5,10 +5,14 @@ from apex.transformer import tensor_parallel
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import vocab_parallel_cross_entropy
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group
from apex.transformer.pipeline_parallel.utils import (
average_losses_across_data_parallel_group,
)
from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func
from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization
from apex.transformer.pipeline_parallel.schedules.common import (
_get_params_for_weight_decay_optimization,
)
from apex.transformer.testing.standalone_bert import bert_model_provider
from apex.transformer.testing import global_vars
@@ -17,9 +21,12 @@ from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator
import warnings
class DebugWarning(Warning):
pass
mode = None
MANUAL_SEED = 42
inds = None
@@ -30,62 +37,74 @@ EASY_MODE = False
EASY_MODE_SIZ = 32
ONCE = False
def download_fancy_data():
#import requests
#response = requests.get('https://internet.com/book.txt')
#text = ' '.join(response.text.split())
text = """
# import requests
# response = requests.get('https://internet.com/book.txt')
# text = ' '.join(response.text.split())
text = """
An original sentence not subject to any license restrictions, copyright, or royalty payments. Nothing to see here. Commercial or non-commercial use. Research or non-research purposes. The quick brown fox jumps over the lazy dog. Lorem ipsum.
"""
text = text*1024
encoded = text.encode('ascii', 'replace')
ints = [int(encoded[i]) for i in range(len(encoded))]
return torch.tensor(ints)
text = text * 1024
encoded = text.encode("ascii", "replace")
ints = [int(encoded[i]) for i in range(len(encoded))]
return torch.tensor(ints)
# build a batch given sequence_len and batch size
def generate_fancy_data_labels(sequence_len, batch_size):
global data_idx
global inds
global masks
global MANUAL_SEED
temps = []
for i in range(batch_size):
if inds is None or data_idx >= len(inds):
# hack as use of RNG will fall out of sync due to pipelines being different
torch.manual_seed(MANUAL_SEED)
inds = torch.randperm(effective_length, device='cuda')
masks = (torch.rand(len(inds)//batch_size + 1, batch_size, sequence_len, device='cuda') >= MASK_PROB).long()
MANUAL_SEED += 1
print("new epoch", len(inds))
data_idx = 0
print("my start", inds[0:5])
print("masks_checksum:", torch.sum(masks))
if EASY_MODE:
data_idx_ = data_idx % EASY_MODE_SIZ
else:
data_idx_ = data_idx
offset = inds[data_idx_] #* SEQUENCE_LEN
data_idx += 1
curr = fancy_data[offset:offset+sequence_len].clone().detach()
temps.append(curr)
temp = torch.stack(temps, dim=0).cuda()
mask = masks[data_idx//batch_size]
mask_not = torch.logical_not(mask).long()
data = mask * temp + mask_not*124
label = temp
if parallel_state.get_tensor_model_parallel_rank() == 0:
data_dict = {"text": data, "label": label, "mask_not": mask_not}
else:
data_dict = None
keys = ["text", "label", "mask_not"]
dtype = torch.int64
broadcasted_data = tensor_parallel.broadcast_data(keys, data_dict, torch.long)
return (broadcasted_data["text"].long(), broadcasted_data["label"].long(), broadcasted_data["mask_not"])
global data_idx
global inds
global masks
global MANUAL_SEED
temps = []
for i in range(batch_size):
if inds is None or data_idx >= len(inds):
# hack as use of RNG will fall out of sync due to pipelines being different
torch.manual_seed(MANUAL_SEED)
inds = torch.randperm(effective_length, device="cuda")
masks = (
torch.rand(
len(inds) // batch_size + 1, batch_size, sequence_len, device="cuda"
)
>= MASK_PROB
).long()
MANUAL_SEED += 1
print("new epoch", len(inds))
data_idx = 0
print("my start", inds[0:5])
print("masks_checksum:", torch.sum(masks))
if EASY_MODE:
data_idx_ = data_idx % EASY_MODE_SIZ
else:
data_idx_ = data_idx
offset = inds[data_idx_] # * SEQUENCE_LEN
data_idx += 1
curr = fancy_data[offset : offset + sequence_len].clone().detach()
temps.append(curr)
temp = torch.stack(temps, dim=0).cuda()
mask = masks[data_idx // batch_size]
mask_not = torch.logical_not(mask).long()
data = mask * temp + mask_not * 124
label = temp
if parallel_state.get_tensor_model_parallel_rank() == 0:
data_dict = {"text": data, "label": label, "mask_not": mask_not}
else:
data_dict = None
keys = ["text", "label", "mask_not"]
dtype = torch.int64
broadcasted_data = tensor_parallel.broadcast_data(keys, data_dict, torch.long)
return (
broadcasted_data["text"].long(),
broadcasted_data["label"].long(),
broadcasted_data["mask_not"],
)
easy_data = None
def fwd_step_func(batch, model):
data, label, loss_mask = batch
y = model(data, torch.ones_like(data), lm_labels=label)
@@ -94,31 +113,38 @@ def fwd_step_func(batch, model):
global ONCE
output_tensor, _ = output_tensor
lm_loss_ = output_tensor.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
averaged_loss = average_losses_across_data_parallel_group([lm_loss])
if data_idx >= 1536:
assert lm_loss < 4.8
if not ONCE:
print("LOSS OK")
ONCE = True
return lm_loss, {'avg': averaged_loss}
return lm_loss, {"avg": averaged_loss}
return y, loss_func
def train(model, optim, virtual_pipeline_model_parallel_size, pipeline_model_parallel_size):
def train(
model, optim, virtual_pipeline_model_parallel_size, pipeline_model_parallel_size
):
sequence_len = global_vars.get_args().seq_length
micro_batch_size = global_vars.get_args().micro_batch_size
hidden_size = global_vars.get_args().hidden_size
forward_backward_func = get_forward_backward_func(virtual_pipeline_model_parallel_size, pipeline_model_parallel_size)
forward_backward_func = get_forward_backward_func(
virtual_pipeline_model_parallel_size, pipeline_model_parallel_size
)
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
for _ in range(16):
batch = generate_fancy_data_labels(sequence_len, batch_size)
optim.zero_grad()
forward_backward_func(fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape)
forward_backward_func(
fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape
)
optim.step()
if __name__ == '__main__':
if __name__ == "__main__":
global fancy_data
global effective_length
@@ -128,13 +154,12 @@ if __name__ == '__main__':
effective_length = fancy_data.size(0) // global_vars.get_args().seq_length
effective_length = fancy_data.size(0) - global_vars.get_args().seq_length
initialize_distributed()
world_size = torch.distributed.get_world_size()
failure = None
try:
args = global_vars.get_args()
args.padded_vocab_size = 128 # needed in standalone gpt
args.padded_vocab_size = 128 # needed in standalone gpt
batch_size = args.global_batch_size
micro_batch_size = args.micro_batch_size
setup_microbatch_calculator(
@@ -147,27 +172,44 @@ if __name__ == '__main__':
virtual_pipeline_model_parallel_size = 2
pipeline_model_parallel_size = world_size
parallel_state.initialize_model_parallel(
args.tensor_model_parallel_size, args.pipeline_model_parallel_size, virtual_pipeline_model_parallel_size)
pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size,
)
pipeline_model_parallel_size = (
parallel_state.get_pipeline_model_parallel_world_size()
)
tensor_parallel.random.model_parallel_cuda_manual_seed(0)
model = build_model(
bert_model_provider,
wrap_with_ddp=True,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
cpu_offload=args.cpu_offload)
cpu_offload=args.cpu_offload,
)
assert isinstance(model, list)
assert len(model) == (1 if virtual_pipeline_model_parallel_size is None else virtual_pipeline_model_parallel_size)
assert len(model) == (
1
if virtual_pipeline_model_parallel_size is None
else virtual_pipeline_model_parallel_size
)
_param_groups = _get_params_for_weight_decay_optimization(model)
optim = torch.optim.Adam(_param_groups)
print(effective_length)
print(fancy_data.size(0))
train(model, optim, virtual_pipeline_model_parallel_size, args.pipeline_model_parallel_size)
train(
model,
optim,
virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_size,
)
except Exception as e:
failure = str(e)
finally:
parallel_state.destroy_model_parallel()
if failure is not None:
warnings.warn(f"Minimal BERT Pipeline Parallel Failed with: {failure}", DebugWarning)
warnings.warn(
f"Minimal BERT Pipeline Parallel Failed with: {failure}", DebugWarning
)
print(f"Minimal BERT Pipeline Parallel Failed with: {failure}")
torch.distributed.barrier()
print(TEST_SUCCESS_MESSAGE)
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import set_random_seed
from apex.transformer.testing.commons import IdentityLayer
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
def torch_cross_entropy(batch_size, seq_length, vocab_size,
logits_scale, seed):
set_random_seed(seed)
identity = IdentityLayer((batch_size, seq_length, vocab_size),
scale=logits_scale).cuda()
logits = identity()
target = torch.cuda.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size)
loss = F.cross_entropy(logits.view(-1, logits.size()[-1]),
target.view(-1),
reduction='none').view_as(target).mean()
loss.backward()
return loss, identity.weight.grad
def tensor_sharded_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
set_random_seed(seed)
identity = IdentityLayer((batch_size, seq_length, vocab_size), scale=logits_scale).cuda()
logits = identity()
logits_parallel = tensor_parallel.scatter_to_tensor_model_parallel_region(logits)
target = torch.cuda.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size)
logits_parallel_ = logits_parallel.clone().detach()
loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
loss.backward()
# check for mutation
assert torch.equal(logits_parallel_, logits_parallel)
return loss, identity.weight.grad
def test_cross_entropy(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing cross entropy with model parallel size {} ...'.
format(tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
batch_size = 13
seq_length = 17
vocab_size_per_partition = 11
logits_scale = 1000.0
vocab_size = vocab_size_per_partition * tensor_model_parallel_size
seed = 1234
loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed)
loss_mpu, grad_mpu = tensor_sharded_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed)
error = loss_torch.sub_(loss_mpu).abs().max()
print(' max error in loss on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = grad_torch.sub_(grad_mpu).abs().max()
print(' max error in grad on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == '__main__':
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test cross entropy')
test_cross_entropy(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import operator
import torch
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import data as data_utils
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
def test_broadcast_data(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing broadcast_data with model parallel size {} ...'.
format(tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
torch.manual_seed(1234 + parallel_state.get_data_parallel_rank())
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
key_size_t = {
'key1': [7, 11],
'key2': [8, 2, 1],
'key3': [13],
'key4': [5, 1, 2],
'key5': [5, 12],
}
keys = list(key_size_t.keys())
data = {}
data_t = {}
for key in key_size_t:
data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
data_t[key] = data[key].clone()
data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
data_t['keyX'] = data['keyX'].clone()
if parallel_state.get_tensor_model_parallel_rank() != 0:
data = None
data_utils._check_data_types(keys, data_t, torch.int64)
key_size, key_numel, \
total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
for key in keys:
assert key_size[key] == key_size_t[key]
total_numel_t = 0
for key in keys:
target_size = functools.reduce(operator.mul, key_size_t[key], 1)
assert key_numel[key] == target_size
total_numel_t += target_size
assert total_numel == total_numel_t
data_b = data_utils.broadcast_data(keys, data, torch.int64)
for key in keys:
tensor = data_t[key].cuda()
assert data_b[key].sub(tensor).abs().max() == 0
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == '__main__':
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test broadcast data')
test_broadcast_data(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
@@ -44,7 +44,13 @@ HIDDEN_SIZE = 16
def Dataset(num_samples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
return [(torch.randn(HIDDEN_SIZE, HIDDEN_SIZE), torch.randn(HIDDEN_SIZE // 2, HIDDEN_SIZE // 2)) for _ in range(num_samples)]
return [
(
torch.randn(HIDDEN_SIZE, HIDDEN_SIZE),
torch.randn(HIDDEN_SIZE // 2, HIDDEN_SIZE // 2),
)
for _ in range(num_samples)
]
# Run forward & backward with dynamic batch size.
@@ -66,9 +72,13 @@ def run_interleaved_with_dynamic_batch_size(
parallel_state.initialize_model_parallel(
1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size
)
pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
pipeline_model_parallel_size = (
parallel_state.get_pipeline_model_parallel_world_size()
)
print_separator(f"BatchSamplerCls: {BatchSamplerCls.__name__}, forward_only: {forward_only}")
print_separator(
f"BatchSamplerCls: {BatchSamplerCls.__name__}, forward_only: {forward_only}"
)
model = build_model(
model_provider_func,
@@ -158,7 +168,10 @@ if __name__ == "__main__":
args.micro_batch_size,
1, # args.data_parallel_size,
)
for BatchSamplerCls in (MegatronPretrainingSampler, MegatronPretrainingRandomSampler):
for BatchSamplerCls in (
MegatronPretrainingSampler,
MegatronPretrainingRandomSampler,
):
for forward_only in (False, True):
n_tests += 1
pipeline_model_parallel_size = world_size
......
import torch
import os
from functools import partial
from typing import List
import time
from functools import partial
import torch
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import model_parallel_cuda_manual_seed
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group
from apex.transformer.pipeline_parallel.utils import (
average_losses_across_data_parallel_group,
)
from apex.transformer.pipeline_parallel.utils import get_ltor_masks_and_position_ids
from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import forward_backward_pipelining_without_interleaving
from apex.transformer.pipeline_parallel.schedules.common import (
_get_params_for_weight_decay_optimization,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
forward_backward_pipelining_without_interleaving,
)
from apex.transformer.testing.standalone_gpt import gpt_model_provider
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
@@ -23,41 +28,45 @@ inds = None
data_idx = 0
N_VOCAB = 128
def download_fancy_data():
#import requests
#response = requests.get('https://internet.com/book.txt')
#text = ' '.join(response.text.split())
text = """
# import requests
# response = requests.get('https://internet.com/book.txt')
# text = ' '.join(response.text.split())
text = """
An original sentence not subject to any license restrictions, copyright, or royalty payments. Nothing to see here. Commercial or non-commercial use. Research or non-research purposes. The quick brown fox jumps over the lazy dog. Lorem ipsum.
"""
text = text*1024
encoded = text.encode('ascii', 'replace')
ints = [int(encoded[i]) for i in range(len(encoded))]
return torch.tensor(ints)
text = text * 1024
encoded = text.encode("ascii", "replace")
ints = [int(encoded[i]) for i in range(len(encoded))]
return torch.tensor(ints)
# build a batch given sequence_len and batch size
def generate_fancy_data_labels(sequence_len, batch_size):
global data_idx
global inds
global MANUAL_SEED
temps = list()
for i in range(batch_size):
if inds is None or data_idx >= len(inds):
# hack as use of RNG will fall out of sync due to pipelines being different
model_parallel_cuda_manual_seed(MANUAL_SEED)
inds = torch.randperm(effective_length, device='cuda')
MANUAL_SEED += 1
data_idx = 0
data_idx_ = data_idx
offset = inds[data_idx_]
data_idx += 1
curr = fancy_data[offset:offset+sequence_len+1].clone().detach()
temps.append(curr)
temp = torch.stack(temps, dim=0).cuda()
return temp
global data_idx
global inds
global MANUAL_SEED
temps = list()
for i in range(batch_size):
if inds is None or data_idx >= len(inds):
# hack as use of RNG will fall out of sync due to pipelines being different
model_parallel_cuda_manual_seed(MANUAL_SEED)
inds = torch.randperm(effective_length, device="cuda")
MANUAL_SEED += 1
data_idx = 0
data_idx_ = data_idx
offset = inds[data_idx_]
data_idx += 1
curr = fancy_data[offset : offset + sequence_len + 1].clone().detach()
temps.append(curr)
temp = torch.stack(temps, dim=0).cuda()
return temp
easy_data = None
def get_batch(int_tensors: List[torch.Tensor]):
data = int_tensors[0]
# Unpack.
@@ -84,7 +93,7 @@ def loss_func(loss_mask, output_tensor):
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
return loss, {"lm loss": averaged_loss[0]}
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L86
@@ -103,24 +112,31 @@ def train(model, optim, pipeline_model_parallel_size):
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
runtime = 0
#training loop
# training loop
for i in range(3):
since = time.time()
if torch.distributed.get_rank() == 0:
print('begin iter', i)
batch = [generate_fancy_data_labels(args.seq_length, args.global_batch_size) for _ in range(pipeline_model_parallel_size)]
if torch.distributed.get_rank() == 0:
print("finished making batch...")
optim.zero_grad()
fwd_bwd_func(fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape)
if torch.distributed.get_rank() == 0:
print('finished forward step')
optim.step()
if torch.distributed.get_rank() == 0:
print('finished iter', i)
runtime += time.time() - since
return runtime/3.0
if __name__ == '__main__':
since = time.time()
if torch.distributed.get_rank() == 0:
print("begin iter", i)
batch = [
generate_fancy_data_labels(args.seq_length, args.global_batch_size)
for _ in range(pipeline_model_parallel_size)
]
if torch.distributed.get_rank() == 0:
print("finished making batch...")
optim.zero_grad()
fwd_bwd_func(
fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape
)
if torch.distributed.get_rank() == 0:
print("finished forward step")
optim.step()
if torch.distributed.get_rank() == 0:
print("finished iter", i)
runtime += time.time() - since
return runtime / 3.0
if __name__ == "__main__":
global fancy_data
global effective_length
@@ -134,7 +150,6 @@ if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
failure = None
args.padded_vocab_size = 128
batch_size = args.global_batch_size
@@ -148,16 +163,19 @@ if __name__ == '__main__':
)
world_size = torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=args.tensor_model_parallel_size,\
pipeline_model_parallel_size_=args.pipeline_model_parallel_size)
tensor_model_parallel_size_=args.tensor_model_parallel_size,
pipeline_model_parallel_size_=args.pipeline_model_parallel_size,
)
pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
pipeline_model_parallel_size = (
parallel_state.get_pipeline_model_parallel_world_size()
)
model_parallel_cuda_manual_seed(0)
model = build_model(
gpt_model_provider,
wrap_with_ddp=True,
virtual_pipeline_model_parallel_size=None,
cpu_offload=args.cpu_offload
cpu_offload=args.cpu_offload,
)
assert isinstance(model, list), model
_param_groups = _get_params_for_weight_decay_optimization(model)
......
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex.transformer import parallel_state
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
def test_initialize_model_parallel(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing initialize_model_parallel with size {} ...'.format(
tensor_model_parallel_size))
tensor_model_parallel_size_ = min(
tensor_model_parallel_size,
torch.distributed.get_world_size(),
)
assert not parallel_state.model_parallel_is_initialized()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_)
assert parallel_state.model_parallel_is_initialized()
# Checks.
def check(group, world_size, rank):
assert world_size == torch.distributed.get_world_size(group=group)
assert rank == torch.distributed.get_rank(group=group)
# Model parallel.
world_size = tensor_model_parallel_size_
rank = torch.distributed.get_rank() % tensor_model_parallel_size_
assert world_size == parallel_state.get_tensor_model_parallel_world_size()
assert rank == parallel_state.get_tensor_model_parallel_rank()
check(parallel_state.get_tensor_model_parallel_group(), world_size, rank)
# Data parallel.
world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
rank = torch.distributed.get_rank() // tensor_model_parallel_size
assert world_size == parallel_state.get_data_parallel_world_size()
assert rank == parallel_state.get_data_parallel_rank()
check(parallel_state.get_data_parallel_group(), world_size, rank)
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
if torch.distributed.get_rank() == 0:
print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format(
tensor_model_parallel_size_))
tensor_model_parallel_size = min(
tensor_model_parallel_size_,
torch.distributed.get_world_size(),
)
assert not parallel_state.model_parallel_is_initialized()
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
assert parallel_state.model_parallel_is_initialized()
# Checks
src_rank = torch.distributed.get_rank() - parallel_state.get_tensor_model_parallel_rank()
assert parallel_state.get_tensor_model_parallel_src_rank() == src_rank
split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
assert split_rank is None
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
def test_pipeline_model_parallel_split_rank():
pipeline_model_parallel_split_rank_ = 1
assert not parallel_state.model_parallel_is_initialized()
parallel_state.initialize_model_parallel(pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank_)
assert parallel_state.model_parallel_is_initialized()
split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
assert split_rank is pipeline_model_parallel_split_rank_
fake_split_rank = 7
parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank)
split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
assert split_rank == fake_split_rank
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
if __name__ == '__main__':
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test initialize model parallel')
test_initialize_model_parallel(tensor_model_parallel_size)
print_separator('test model parallel source rank')
test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
print_separator('test pipeline model parallel split rank')
test_pipeline_model_parallel_split_rank()
tensor_model_parallel_size *= 2
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.init as init
from torch.nn.parameter import Parameter
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import layers
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import set_random_seed
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
class IdentityLayer3D(torch.nn.Module):
def __init__(self, m, n, k):
super(IdentityLayer3D, self).__init__()
self.weight = Parameter(torch.Tensor(m, n, k))
torch.nn.init.xavier_normal_(self.weight)
def forward(self):
return self.weight
def test_parallel_embedding(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing parallel embedding with model parallel size {} ...'.
format(tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
batch_size = 17
seq_length = 23
vocab_size = 48
hidden_size = 16
seed = 1236
set_random_seed(123)
input_data = torch.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()
set_random_seed(seed)
embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()
output = embedding_original(input_data)
loss_original = torch.mul(output, loss_weight).sum()
loss_original.backward()
set_random_seed(seed)
embedding_parallel = layers.ParallelEmbedding(
vocab_size, hidden_size, init_method=init.normal_).cuda()
output = embedding_parallel(input_data)
loss_parallel = torch.mul(output, loss_weight).sum()
loss_parallel.backward()
set_random_seed(seed)
embedding_vocab_parallel = layers.VocabParallelEmbedding(
vocab_size, hidden_size, init_method=init.normal_).cuda()
output = embedding_vocab_parallel(input_data)
loss_vocab_parallel = torch.mul(output, loss_weight).sum()
loss_vocab_parallel.backward()
torch.distributed.barrier()
error = loss_parallel.sub(loss_original).abs()
print(' error in loss (parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
torch.distributed.barrier()
error = loss_vocab_parallel.sub(loss_original).abs()
print(' error in loss (vocab parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad,
hidden_size // tensor_model_parallel_size,
1)[parallel_state.get_tensor_model_parallel_rank()]
error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
print(' error in grad (parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad,
vocab_size // tensor_model_parallel_size,
0)[parallel_state.get_tensor_model_parallel_rank()]
error = embedding_vocab_parallel.weight.grad.sub(
weight_grad_orig).abs().max()
print(' error in grad (vocab parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
def test_initialize_affine_weight(tensor_model_parallel_size, device):
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing initialize_affine_weight with model parallel '
'size: {}'.format(tensor_model_parallel_size))
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
input_size_coeff = 13
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * tensor_model_parallel_size
# ---------------
# Column parallel
# ---------------
weight = torch.empty(output_size_coeff, input_size)
set_random_seed(seed)
if device == 'cpu':
layers._initialize_affine_weight_cpu(weight, output_size, input_size,
output_size_coeff, 0,
torch.nn.init.normal_,
params_dtype=global_vars.get_args().params_dtype,
)
else:
layers._initialize_affine_weight_gpu(weight, torch.nn.init.normal_, 0)
# Target.
set_random_seed(seed)
master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight)
rank = parallel_state.get_tensor_model_parallel_rank()
my_weight = torch.split(master_weight, output_size_coeff,
dim=0)[rank].contiguous().clone()
# Compare.
error = weight.sub(my_weight).abs().max()
torch.distributed.barrier()
print(' column parallel max error (should be zero) on global rank '
'{}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# ------------
# Row parallel
# ------------
weight = torch.empty(output_size, input_size_coeff)
set_random_seed(seed)
if device == 'cpu':
layers._initialize_affine_weight_cpu(
weight, output_size, input_size, input_size_coeff, 1, torch.nn.init.normal_,
params_dtype=global_vars.get_args().params_dtype)
else:
layers._initialize_affine_weight_gpu(weight, torch.nn.init.normal_, 1)
# Target.
set_random_seed(seed)
master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight)
rank = parallel_state.get_tensor_model_parallel_rank()
my_weight = torch.split(master_weight, input_size_coeff,
dim=1)[rank].contiguous().clone()
# Compare.
error = weight.sub(my_weight).abs().max()
torch.distributed.barrier()
print(' row parallel max error (should be zero) on global rank '
'{}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
class IdentityLayer2D(torch.nn.Module):
def __init__(self, m, n):
super(IdentityLayer2D, self).__init__()
self.weight = Parameter(torch.Tensor(m, n))
torch.nn.init.xavier_normal_(self.weight)
def forward(self):
return self.weight
def test_column_parallel_linear(tensor_model_parallel_size):
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing ColumnParallelLinear with model parallel '
'size: {}'.format(tensor_model_parallel_size))
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7
hidden_size = 9
# Network
gradient_accumulation_fusion = True
identity_layer = IdentityLayer3D(batch_size, hidden_size, input_size).cuda()
linear_layer = layers.ColumnParallelLinear(
input_size, output_size, keep_master_weight_for_test=True,
params_dtype=global_vars.get_args().params_dtype,
use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
gradient_accumulation_fusion=gradient_accumulation_fusion,
).cuda()
with torch.no_grad():
linear_layer.weight.main_grad = torch.randn_like(linear_layer.weight)
loss_weight = torch.randn([batch_size, hidden_size, output_size]).cuda()
# Forward
input_ = identity_layer()
output, _ = linear_layer(input_)
assert list(output.shape) == [batch_size, hidden_size, output_size]
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
# TODO (mkozuki): Fix the following commented out lines
# as `gradient_accumulation_fusion` only takes 3D tensors.
# Values.
# dLdY = loss_weight # (7, 9, 17)
# X = identity_layer.weight # (7, 9, 13)
# A = linear_layer.master_weight.cuda() # (17, 13)
# print(f"dLdY.shape, X.shape, A.shape = {dLdY.shape, X.shape, A.shape}")
# dLdA = torch.matmul(dLdY.view(-1, 17).t(), X.view(-1, 13))
# print(f"dLdA.shape = {dLdA.shape}")
# ones = torch.ones(batch_size, hidden_size, 1).cuda()
# print(f"dLdY.shape, ones.shape = {dLdY.shape, ones.shape}")
# dLdb = torch.matmul(ones, dLdY).view(-1)
# dLdX = torch.matmul(dLdY, A)
# rank = parallel_state.get_tensor_model_parallel_rank()
# my_dLdA = torch.split(dLdA, output_size_coeff,
# dim=0)[rank].contiguous().clone()
# error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
# torch.distributed.barrier()
# print(' error in dLdA on global rank {}: {}'.format(
# torch.distributed.get_rank(), error))
# assert error < 1.0e-6
# my_dLdb = torch.split(dLdb, output_size_coeff,
# dim=0)[rank].contiguous().clone()
# error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
# torch.distributed.barrier()
# print(' error in dLdb on global rank {}: {}'.format(
# torch.distributed.get_rank(), error))
# assert error < 1.0e-6
# error = dLdX.sub(identity_layer.weight.grad).abs().max()
# torch.distributed.barrier()
# print(' error in dLdX on global rank {}: {}'.format(
# torch.distributed.get_rank(), error))
# assert error < 1.0e-6
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def test_column_parallel_linear_with_async_allreduce_autocast(tensor_model_parallel_size):
autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7
# Network
identity_layer = IdentityLayer3D(batch_size, batch_size, input_size).cuda()
linear_layer = layers.ColumnParallelLinear(
input_size, output_size, keep_master_weight_for_test=True,
params_dtype=global_vars.get_args().params_dtype,
use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
).cuda()
assert linear_layer.async_tensor_model_parallel_allreduce or tensor_model_parallel_size == 1
# Forward
for dtype in autocast_dtypes:
loss_weight = torch.randn([batch_size, output_size]).cuda()
with torch.cuda.amp.autocast(dtype=dtype):
output, _ = linear_layer(identity_layer())
loss = torch.mul(output, loss_weight).sum()
assert output.dtype == dtype
# Backward
loss.backward()
torch.distributed.barrier()
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def test_column_parallel_linear_with_async_allreduce_custom_amp(tensor_model_parallel_size):
dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7
for dtype in dtypes:
# Network
identity_layer = IdentityLayer3D(batch_size, batch_size, input_size).to(device="cuda", dtype=dtype)
linear_layer = layers.ColumnParallelLinear(
input_size, output_size, keep_master_weight_for_test=True,
params_dtype=global_vars.get_args().params_dtype,
use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
).to(device="cuda", dtype=dtype)
# Forward
loss_weight = torch.randn([batch_size, output_size]).cuda()
output, _ = linear_layer(identity_layer())
loss = torch.mul(output, loss_weight).sum()
loss.backward()
torch.distributed.barrier()
assert output.dtype == dtype
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def test_row_parallel_linear(tensor_model_parallel_size):
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing RowParallelLinear with model parallel '
'size: {}'.format(tensor_model_parallel_size))
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7
# Network
identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
linear_layer = layers.RowParallelLinear(
input_size, output_size, keep_master_weight_for_test=True,
params_dtype=global_vars.get_args().params_dtype,
use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
).cuda()
loss_weight = torch.randn([batch_size, output_size]).cuda()
# Forward
input_ = identity_layer()
output, _ = linear_layer(input_)
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
# Values.
dLdY = loss_weight
X = identity_layer.weight
A = linear_layer.master_weight.cuda()
dLdA = torch.matmul(dLdY.t(), X)
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A)
rank = parallel_state.get_tensor_model_parallel_rank()
my_dLdA = torch.split(dLdA, input_size_coeff,
dim=1)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdA on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = dLdb.sub(linear_layer.bias.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdb on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = dLdX.sub(identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdX on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size,
sequence_length):
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
num_att_heads = num_att_heads_per_partition * \
torch.distributed.get_world_size()
hidden_size = hidden_size_per_att_head * num_att_heads
# Network
identity_layer = IdentityLayer3D(batch_size, sequence_length,
hidden_size).cuda()
attention_layer = parallel_state.BertParallelSelfAttention(hidden_size, num_att_heads,
dropout_prob).cuda()
loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
# Forward
input_ = identity_layer()
output = attention_layer(input_, attention_mask)
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
rank = parallel_state.get_tensor_model_parallel_rank()
parallel_state.destroy_model_parallel()
return rank, hidden_size, tensor_model_parallel_size, loss, \
attention_layer, identity_layer
def test_parallel_self_attention(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing ParallelSelfAttention with model parallel '
'size: {}'.format(tensor_model_parallel_size))
num_att_heads_per_partition = 3
hidden_size_per_att_head = 7
dropout_prob = 0.0 # has to be zero
batch_size = 5
sequence_length = 13
rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
attention_layer_1, identity_layer_1 = parallel_self_attention(
1, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
rank, hidden_size, tensor_model_parallel_size, loss, \
attention_layer, identity_layer = parallel_self_attention(
tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
assert hidden_size_1 == hidden_size
error = loss_1.sub(loss).abs().max()
torch.distributed.barrier()
print(' loss error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-6
my_lin_grad_list = torch.split(
attention_layer_1.query_key_value.weight.grad,
hidden_size // tensor_model_parallel_size, 0)[rank::tensor_model_parallel_size]
my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
error = my_lin_grad.sub(
attention_layer.query_key_value.weight.grad).abs().max()
torch.distributed.barrier()
print(' weight gradient error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-6
error = identity_layer_1.weight.grad.sub(
identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' input gradient error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-6
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length):
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
num_att_heads = num_att_heads_per_partition * \
torch.distributed.get_world_size()
hidden_size = hidden_size_per_att_head * num_att_heads
intermediate_size = 4 * hidden_size
# Network
identity_layer = IdentityLayer3D(batch_size, sequence_length,
hidden_size).cuda()
transformer_layer = parallel_state.BertParallelTransformerLayer(
hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
torch.nn.functional.relu, 1.0e-5).cuda()
loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
# Forward
input_ = identity_layer()
output = transformer_layer(input_, attention_mask)
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
rank = parallel_state.get_tensor_model_parallel_rank()
parallel_state.destroy_model_parallel()
return rank, hidden_size, tensor_model_parallel_size, loss, \
transformer_layer, identity_layer
def test_parallel_transformer_layer(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing ParallelTransformerLayer with model parallel '
'size: {}'.format(tensor_model_parallel_size))
num_att_heads_per_partition = 3
hidden_size_per_att_head = 7
batch_size = 5
sequence_length = 13
rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
transformer_layer_1, identity_layer_1 = parallel_transformer(
1, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length)
rank, hidden_size, tensor_model_parallel_size, loss, \
transformer_layer, identity_layer = parallel_transformer(
tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length)
error = loss_1.sub(loss).abs().max()
torch.distributed.barrier()
print(' loss error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-5, 'error: {}'.format(error)
error = identity_layer_1.weight.grad.sub(
identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' input gradient error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-5, 'error: {}'.format(error)
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == '__main__':
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
exceptions = []
print_separator('test initialize affine weight cpu')
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
try:
test_initialize_affine_weight(tensor_model_parallel_size, 'cpu')
except Exception as e:
exceptions.append(f"test_initialize_affine_weight-cpu with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
# Reset groups
parallel_state.destroy_model_parallel()
print_separator('test initialize affine weight gpu')
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
try:
test_initialize_affine_weight(tensor_model_parallel_size, 'gpu')
except Exception as e:
exceptions.append(f"test_initialize_affine_weight-gpu with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
# Deleted, replaced with vocab parallel embedding?
#tensor_model_parallel_size = 1
#while tensor_model_parallel_size <= world_size:
# print_separator('test parallel embedding')
# test_parallel_embedding(tensor_model_parallel_size)
# tensor_model_parallel_size *= 2
print_separator('test column-parallel linear')
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
try:
test_column_parallel_linear(tensor_model_parallel_size)
except Exception as e:
exceptions.append(f"test_column_parallel_linear with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
print_separator('test row-parallel linear')
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
try:
test_row_parallel_linear(tensor_model_parallel_size)
except Exception as e:
exceptions.append(f"test_row_parallel_linear with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
print_separator("test ColumnParallelLinearWithAsyncAllreduce - autocast")
tensor_model_parallel_size = 2
while tensor_model_parallel_size <= world_size:
try:
test_column_parallel_linear_with_async_allreduce_autocast(tensor_model_parallel_size)
except Exception as e:
exceptions.append(f"test_column_parallel_linear_with_async_allreduce_autocast with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
print_separator("test ColumnParallelLinearWithAsyncAllreduce - custom AMP")
tensor_model_parallel_size = 2
while tensor_model_parallel_size <= world_size:
try:
test_column_parallel_linear_with_async_allreduce_custom_amp(tensor_model_parallel_size)
except Exception as e:
exceptions.append(f"test_column_parallel_linear_with_async_allreduce_custom_amp with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
if exceptions:
raise RuntimeError("\n".join(exceptions))
# Deleted
#print_separator('test parallel self-attention')
#tensor_model_parallel_size = 1
#while tensor_model_parallel_size <= world_size:
# test_parallel_self_attention(tensor_model_parallel_size)
# tensor_model_parallel_size *= 2
# Deleted because ParallelTransformerLayer no longer exists
# print_separator('test parallel transformer')
# tensor_model_parallel_size = 1
# while tensor_model_parallel_size <= world_size:
# test_parallel_transformer_layer(tensor_model_parallel_size)
# tensor_model_parallel_size *= 2
import torch
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import mappings
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import initialize_distributed
global_vars.set_global_variables()
def test__reduce(args, tensor_model_parallel_size):
print("Testing reduction size =", tensor_model_parallel_size)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
assert torch.equal(
mappings._reduce(torch.full((10, 10, 10, 10), (50))),
torch.full((10, 10, 10, 10), 50 * tensor_model_parallel_size),
)
parallel_state.destroy_model_parallel()
print("Passed!")
def test__split(args, tensor_model_parallel_size):
print("Testing splitting size =", tensor_model_parallel_size)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
listy = []
for i in range(tensor_model_parallel_size):
listy.append(torch.randn(10, 1))
x = torch.cat(tuple(listy), 1)
out = mappings._split(x)
assert torch.equal(out, listy[parallel_state.get_tensor_model_parallel_rank()])
parallel_state.destroy_model_parallel()
print("Passed!")
def test__gather(args, tensor_model_parallel_size):
print("Testing gathering size =", tensor_model_parallel_size)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
assert torch.equal(
mappings._gather(torch.tensor([parallel_state.get_tensor_model_parallel_rank()])),
torch.tensor(list(range(tensor_model_parallel_size))),
)
parallel_state.destroy_model_parallel()
print("Passed!")
if __name__ == "__main__":
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
args = global_vars.get_args()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
test__reduce(args, tensor_model_parallel_size)
test__split(args, tensor_model_parallel_size)
test__gather(args, tensor_model_parallel_size)
tensor_model_parallel_size *= 2
print(">> passed the test :-)")
from functools import partial
import logging
from typing import List
import torch
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization
from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import forward_backward_pipelining_without_interleaving
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group
from apex.transformer.pipeline_parallel.utils import get_ltor_masks_and_position_ids
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import update_num_microbatches
from apex.transformer.tensor_parallel import model_parallel_cuda_manual_seed
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.standalone_gpt import gpt_model_provider
from apex.transformer.log_util import get_transformer_logger, set_logging_level
set_logging_level(logging.NOTSET)
_logger = get_transformer_logger("megatron_gpt_pipeline_test")
global_vars.set_global_variables()
N_VOCAB = 8192
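# Sample one extra token per sequence so that get_batch() can build the inputs and the
# shifted labels from the same tensor.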
def generate_batch(batch_size, sequence_length):
size = batch_size, sequence_length + 1
int_tensor = torch.randint(low=0, high=N_VOCAB, size=size, dtype=torch.long).cuda()
    return (int_tensor,)
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L44
def get_batch(int_tensors: List[torch.Tensor]):
data = int_tensors[0]
# Unpack.
tokens_ = data.long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
N_VOCAB, # tokenizer.eod,
False, # args.reset_position_ids,
False, # args.reset_attention_mask,
False, # args.eod_mask_loss,
)
return tokens, labels, loss_mask, attention_mask, position_ids
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L75
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L86
# TODO (mkozuki): Currently I'm seeing no attribute `word_embeddings` which looks weird.
def forward_step(batch, model):
"""Forward step."""
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(batch)
output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
return output_tensor, partial(loss_func, loss_mask)
def run_gpt(pipeline_model_parallel_size, virtual_pipeline_model_parallel_size=None, forward_only=False):
parallel_state.initialize_model_parallel(1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size)
model_parallel_cuda_manual_seed(42)
model = build_model(
gpt_model_provider, True,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size)
_logger.debug("building model")
assert isinstance(model, list)
    assert len(model) == (
        1 if virtual_pipeline_model_parallel_size is None else virtual_pipeline_model_parallel_size
    )
_param_groups = _get_params_for_weight_decay_optimization(model)
torch.optim.Adam(_param_groups)
if parallel_state.is_pipeline_last_stage():
_logger.debug("checking `word_embeddings` existence")
for m in model:
assert hasattr(m, "word_embeddings")
args = global_vars.get_args()
if virtual_pipeline_model_parallel_size is None:
batch = generate_batch(args.global_batch_size, args.seq_length)
else:
batch = [generate_batch(args.global_batch_size, args.seq_length) for _ in range(virtual_pipeline_model_parallel_size)]
_logger.debug("preparing batch")
if virtual_pipeline_model_parallel_size is None:
fwd_bwd_func = forward_backward_pipelining_without_interleaving
else:
fwd_bwd_func = _forward_backward_pipelining_with_interleaving
_logger.debug(f"selecting forward_backward func: {fwd_bwd_func}")
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
_logger.debug(f"`tensor_shape`: {tensor_shape}")
fwd_bwd_func(forward_step, batch, model, forward_only=forward_only, tensor_shape=tensor_shape)
_logger.debug(TEST_SUCCESS_MESSAGE)
if __name__ == "__main__":
initialize_distributed()
args = global_vars.get_args()
args.padded_vocab_size = N_VOCAB
setup_microbatch_calculator(
args.rank,
args.rampup_batch_size,
args.global_batch_size,
args.micro_batch_size,
1, # args.data_parallel_size,
)
update_num_microbatches(0, True)
print_separator("run GPT model")
try:
run_gpt(torch.distributed.get_world_size())
    # TODO(mkozuki): handle the exception properly; for now just log it,
    # since this script is not run by CI.
    except Exception as e:
        _logger.debug(str(e))
finally:
parallel_state.destroy_model_parallel()
import itertools
from typing import Optional
import warnings
import torch
from torch.cuda.amp import GradScaler
from apex._autocast_utils import _get_autocast_dtypes
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import get_forward_backward_func
from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization
from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import forward_backward_pipelining_without_interleaving
from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import update_num_microbatches
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import model_provider_func
from apex.transformer.testing.commons import fwd_step_func
from apex.transformer.log_util import get_transformer_logger, set_logging_level
# set_logging_level("INFO")
_logger = get_transformer_logger("pipeline_parallel_test")
global_vars.set_global_variables()
batch_size, micro_batch_size = None, None
hidden_size = 16
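# Map each test name to one of the three available schedules: no pipelining (single stage),
# pipelining without interleaving (1F1B), and the experimental interleaved schedule.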
fwd_bwd_functions = {
"no_pipelining": forward_backward_no_pipelining,
"no_interleaving": forward_backward_pipelining_without_interleaving,
"interleaving": _forward_backward_pipelining_with_interleaving,
}
# Run forward & backward for one minibatch.
def forward_backward_func_template(
args,
name: str,
forward_backward_func,
pipeline_model_parallel_size: int,
forward_only: bool,
dtype: torch.dtype,
grad_scaler: Optional[GradScaler],
deallocate_pipeline_outputs: bool,
data_parallel_size: int,
) -> None:
print_separator(
f"{name}, {dtype}, use grad_scaler: {grad_scaler is not None}, "
f"deallocate_pipeline_outputs: {deallocate_pipeline_outputs}, "
f"pipeline parallel size: {pipeline_model_parallel_size}, "
f"data parallel size: {data_parallel_size}"
)
virtual_pipeline_model_parallel_size = 2 if name == "interleaving" else None
if name == "no_pipelining":
        # note (mkozuki): `forward_backward_no_pipelining` is **NOT** compatible with
        # pipeline_model_parallel_size > 1, so initialize model parallel with both the
        # tensor and pipeline model parallel sizes set to 1.
parallel_state.initialize_model_parallel(1, 1, None)
_reconfigure_microbatch_calculator(
args.rank,
args.rampup_batch_size,
args.global_batch_size,
args.micro_batch_size,
parallel_state.get_data_parallel_world_size(),
)
else:
        # NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is necessary to enable the
        # interleaving schedule. In Megatron, `args.virtual_pipeline_model_parallel_size` is
        # computed in megatron/arguments.py and used ubiquitously, but this test uses a
        # custom model, so it is safe to set the value directly here.
parallel_state.initialize_model_parallel(
data_parallel_size, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size)
_reconfigure_microbatch_calculator(
args.rank,
args.rampup_batch_size,
args.global_batch_size,
args.micro_batch_size,
parallel_state.get_data_parallel_world_size(),
)
if virtual_pipeline_model_parallel_size is not None:
# Check the experimental warning message
get_forward_backward_func(virtual_pipeline_model_parallel_size, pipeline_model_parallel_size)
pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
model = build_model(
model_provider_func,
wrap_with_ddp=True,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
hidden_size=hidden_size,
)
assert isinstance(model, list)
assert len(model) == (1 if virtual_pipeline_model_parallel_size is None else virtual_pipeline_model_parallel_size)
_param_groups = _get_params_for_weight_decay_optimization(model)
torch.optim.Adam(_param_groups, lr=1e-4)
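    # Build one data-parallel rank's worth of input, then shrink the leading dimension to
    # micro_batch_size: `tensor_shape` is presumably the per-micro-batch activation shape
    # that the schedule communicates between pipeline stages.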
tensor_shape = [batch_size // parallel_state.get_data_parallel_world_size(), hidden_size, hidden_size]
batch = (torch.randn(tensor_shape).cuda(),)
tensor_shape[0] = micro_batch_size
update_num_microbatches(0)
forward_backward_func(
fwd_step_func, batch, model, forward_only=forward_only, tensor_shape=tensor_shape,
dtype=dtype, grad_scaler=grad_scaler, deallocate_pipeline_outputs=deallocate_pipeline_outputs,
)
if not forward_only:
for m in model:
for p in m.parameters():
if p.grad is None:
raise RuntimeError("grad not found")
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == "__main__":
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
n_tests = 0
failures = []
initialize_distributed()
world_size = torch.distributed.get_world_size()
args = global_vars.get_args()
batch_size = args.global_batch_size
micro_batch_size = args.micro_batch_size
dtypes = [torch.float32] + _get_autocast_dtypes()
for forward_only, name, dtype, deallocate_pipeline_outputs in itertools.product(
(True, False),
fwd_bwd_functions.keys(),
dtypes,
(True, False),
):
forward_backward_func = fwd_bwd_functions[name]
if name == "interleaving" and torch.cuda.device_count() <= 2:
            warnings.warn(
                f"Only {torch.cuda.device_count()} GPUs are available, so skipping {name}; "
                "the interleaved pipeline schedule requires more than 2 GPUs."
            )
continue
grad_scaler = torch.cuda.amp.GradScaler(init_scale=4.0) if dtype == torch.half else None
n_tests += 1
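        # Sizing heuristic: with 8 or more (even) GPUs, reserve 2 ranks for data parallelism
        # and give the rest to the pipeline; otherwise the whole world forms the pipeline.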
data_parallel_size = 2 if world_size >= 8 and world_size % 2 == 0 else 1
pipeline_model_parallel_size = world_size if world_size < 8 else world_size // 2
try:
forward_backward_func_template(
args,
name,
forward_backward_func,
pipeline_model_parallel_size,
forward_only,
dtype=dtype,
grad_scaler=grad_scaler,
deallocate_pipeline_outputs=deallocate_pipeline_outputs,
data_parallel_size=data_parallel_size,
)
except Exception as e:
failures.append(
f"\t# {name} failed with pipeline size: {pipeline_model_parallel_size} "
f"and forward_only: {forward_only}\n"
f"pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()}, "
f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}\n"
f"{str(e)}"
)
print(failures[-1])
finally:
parallel_state.destroy_model_parallel()
print_separator("TEST RESULT")
if failures:
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print("\n".join(failures))
msg = f"{len(failures)} / {n_tests} cases failed"
raise RuntimeError(msg)
else:
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print("### PASS!")
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
def test_set_cuda_rng_state(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing set_rng_state with size {} ...'.
format(tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
size = 123
seed = 1234
torch.cuda.manual_seed(seed)
tensor = torch.cuda.FloatTensor(size)
# Get the state
rng_state = torch.cuda.get_rng_state()
rng_state_copy = rng_state.clone()
# Do some stuff.
for _ in range(5):
torch.randn(size, out=tensor)
result_1 = tensor.clone()
assert rng_state.sub(rng_state_copy).max() == 0
assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0
# State should be different.
new_rng_state = torch.cuda.get_rng_state()
max_diff = new_rng_state.sub(rng_state).max()
print(' max diff in rng state (should be non-zero) on global rank {}: {}'.
format(torch.distributed.get_rank(), max_diff))
assert max_diff > 0
# Reset the rng state and do the same stuff.
tensor_parallel.random._set_cuda_rng_state(rng_state)
for _ in range(5):
torch.randn(size, out=tensor)
tensor_parallel.random._set_cuda_rng_state(rng_state)
for _ in range(5):
torch.randn(size, out=tensor)
result_2 = tensor.clone()
# Results should be the same
error = result_2.sub(result_1).abs().max()
print(' max error in generated tensors (should be zero) on '
'global rank {}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Input state should have remained intact.
error = rng_state.sub(rng_state_copy).max()
print(' max error in rng state (should be zero) on global rank {}: {}'.
format(torch.distributed.get_rank(), error))
assert error == 0
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
def test_cuda_rng_tracker(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing cuda rng tracker with size {} ...'.
format(tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed_1 = 1234
seed_2 = 4321
size = [12, 21]
tensor = torch.cuda.FloatTensor(size)
# Set to seed_1 and generate two tensors.
torch.cuda.manual_seed(seed_1)
torch.randn(size, out=tensor)
target_11 = tensor.clone()
torch.randn(size, out=tensor)
target_12 = tensor.clone()
# Set to seed_2 and generate two tensors.
torch.cuda.manual_seed(seed_2)
torch.randn(size, out=tensor)
target_21 = tensor.clone()
torch.randn(size, out=tensor)
target_22 = tensor.clone()
# Now if we interleave seed_1 and seed_2,
# we should still get the same tensors
torch.cuda.manual_seed(seed_1)
tensor_parallel.random.get_cuda_rng_tracker().add('test', seed_2)
torch.randn(size, out=tensor)
result_11 = tensor.clone()
with tensor_parallel.random.get_cuda_rng_tracker().fork('test'):
torch.randn(size, out=tensor)
result_21 = tensor.clone()
torch.randn(size, out=tensor)
result_12 = tensor.clone()
with tensor_parallel.random.get_cuda_rng_tracker().fork('test'):
torch.randn(size, out=tensor)
result_22 = tensor.clone()
diff = result_11.sub(result_21).abs().max()
diff = min(diff, result_12.sub(result_22).abs().max())
print(' max diff in generated tensors (should be non-zero) on '
'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
assert diff > 1.0e-6
error = max(result_11.sub(target_11).abs().max(),
result_12.sub(target_12).abs().max())
error = max(error, result_21.sub(target_21).abs().max())
error = max(error, result_22.sub(target_22).abs().max())
print(' max error in generated tensors (should be zero) on '
'global rank {}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset the tracker
tensor_parallel.random.get_cuda_rng_tracker().reset()
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print(
'> testing model parallel cuda manual seed with size {} ...'.format(
tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
tensor_parallel.random.model_parallel_cuda_manual_seed(12345)
assert torch.cuda.initial_seed() == 12345
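    # Inside the tracker's forked region, each tensor-parallel rank draws from a stream
    # seeded with seed + 2718 + its rank, so parallel regions get distinct randomness.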
with tensor_parallel.random.get_cuda_rng_tracker().fork():
assert (
torch.cuda.initial_seed() ==
12345 + 2718 + parallel_state.get_tensor_model_parallel_rank()
)
# Reset the tracker
tensor_parallel.random.get_cuda_rng_tracker().reset()
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == '__main__':
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test set rng state')
test_set_cuda_rng_state(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test cuda rng tracker')
test_cuda_rng_tracker(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test model parallel cuda manual seed')
test_model_parallel_cuda_manual_seed(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
import torch
from apex.transformer.tensor_parallel import utils
def test_divide():
assert utils.divide(8, 4) == 2
def test_split_tensor_along_last_dim():
inputy = torch.randn((100, 100, 100))
splits = utils.split_tensor_along_last_dim(inputy, 10)
last_dim_shapes = torch.tensor([int(split.size()[-1]) for split in splits])
assert torch.equal(last_dim_shapes, torch.full((10,), 10))
if __name__ == "__main__":
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
test_divide()
test_split_tensor_along_last_dim()
print(">> passed the test :-)")
......@@ -101,7 +101,7 @@ class TestBatchSamplerBehavior(unittest.TestCase):
samples2.append(batch)
if i == 4 - 1:
break
torch.testing.assert_allclose(torch.cat(samples), torch.cat(samples2))
torch.testing.assert_close(torch.cat(samples), torch.cat(samples2))
def test_split_batch(self):
......@@ -127,11 +127,6 @@ class TestBatchSamplerBehavior(unittest.TestCase):
global_batch_size = 16
loader = DataLoader(dataset, batch_sampler=MegatronPretrainingRandomSampler(100, 0, global_batch_size, 0, 1), num_workers=2)
batch = next(iter(loader))
# samples = None
# for i, batch in enumerate(loader):
# # samples = batch
# if i == 0:
# break
for _micro_batch_size in (1, 2, 4, 8):
microbatches = list(split_batch_into_microbatch(
......@@ -139,8 +134,6 @@ class TestBatchSamplerBehavior(unittest.TestCase):
_micro_batch_size=_micro_batch_size,
_global_batch_size=global_batch_size,
))
# print(batch)
# print(microbatches)
self.assertEqual(len(microbatches), global_batch_size // _micro_batch_size)
self.assertEqual(len(microbatches[0][0]), _micro_batch_size)
......
import logging
from typing import Tuple
import torch
import torch.nn.functional as F
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.tensor_parallel import cross_entropy
from apex.transformer.testing.commons import set_random_seed, IdentityLayer
from apex.transformer.testing.distributed_test_base import DistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
def torch_cross_entropy(
batch_size: int, seq_length: int, vocab_size: int, logits_scale: float, seed: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
set_random_seed(seed)
identity = IdentityLayer(
(batch_size, seq_length, vocab_size), scale=logits_scale
).cuda()
logits = identity()
target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size)
loss = (
F.cross_entropy(
logits.view(-1, logits.size()[-1]), target.view(-1), reduction="none"
)
.view_as(target)
.mean()
)
loss.backward()
return loss, identity.weight.grad
def tensor_sharded_cross_entropy(
batch_size, seq_length, vocab_size, logits_scale, seed
):
set_random_seed(seed)
identity = IdentityLayer(
(batch_size, seq_length, vocab_size), scale=logits_scale
).cuda()
logits = identity()
logits_parallel = tensor_parallel.scatter_to_tensor_model_parallel_region(logits)
target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size)
logits_parallel_ = logits_parallel.clone().detach()
loss = cross_entropy.vocab_parallel_cross_entropy(logits_parallel, target).mean()
loss.backward()
# check for mutation
assert torch.equal(logits_parallel_, logits_parallel)
return loss, identity.weight.grad
class VocabParallelCrossEntropy(DistributedTestBase):
def test_cross_entropy(self):
batch_size, sequence_length, vocab_size_per_partition = 13, 17, 11
logits_scale = 1000.0
seed = 1234
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
vocab_size = vocab_size_per_partition * tensor_model_parallel_world_size
loss_torch, grad_torch = torch_cross_entropy(
batch_size, sequence_length, vocab_size, logits_scale, seed
)
(
loss_tensor_parallel,
grad_tensor_parallel,
) = tensor_sharded_cross_entropy(
batch_size, sequence_length, vocab_size, logits_scale, seed
)
torch.testing.assert_close(loss_torch, loss_tensor_parallel)
torch.testing.assert_close(grad_torch, grad_tensor_parallel)
parallel_state.destroy_model_parallel()
if __name__ == "__main__":
common_utils.run_tests()
import logging
import torch.testing
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import data as data_utils
from apex.transformer.testing.distributed_test_base import DistributedTestBase
logging.getLogger("torch").setLevel(logging.WARNING)
class BroadcastDataTest(DistributedTestBase):
def test_broadcast_data(self):
        tensor_model_parallel_world_size: int = self.world_size // (
            1 + int(self.world_size > 1)
        )
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size
)
target_key_size = {
"key1": [7, 11],
"key2": [8, 2, 1],
"key3": [13],
"key4": [5, 1, 2],
"key5": [5, 12],
}
keys = [k for k in target_key_size]
data = {}
data_t = {}
with torch.no_grad():
for key in target_key_size:
data[key] = torch.randint(0, 1000, size=target_key_size[key])
data_t[key] = data[key].clone()
# "key_x" is supposed to be ignored.
data["key_x"] = torch.rand(5)
data_t["key_x"] = data["key_x"].clone()
if parallel_state.get_tensor_model_parallel_rank() != 0:
data = None
data_utils._check_data_types(keys, data_t, torch.int64)
key_size, _, _ = data_utils._build_key_size_numel_dictionaries(keys, data)
for key in keys:
self.assertEqual(target_key_size[key], key_size[key])
broadcasted_data = data_utils.broadcast_data(keys, data, torch.int64)
for key in keys:
torch.testing.assert_close(broadcasted_data[key], data_t[key].cuda())
parallel_state.destroy_model_parallel()
if __name__ == "__main__":
common_utils.run_tests()
......@@ -15,12 +15,20 @@ def attention_mask_func(attention_scores, attention_mask):
return attention_scores.masked_fill(attention_mask, -10000.0)
autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
autocast_dtypes = (
(torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
)
class TestFusedScaleMaskSoftmax(unittest.TestCase):
def _setup_fused_softmax(self, input_in_fp16, input_in_bf16, scale=None, softmax_in_fp32=False, attn_mask_type=AttnMaskType.padding):
def _setup_fused_softmax(
self,
input_in_fp16,
input_in_bf16,
scale=None,
softmax_in_fp32=False,
attn_mask_type=AttnMaskType.padding,
):
fused_fn = FusedScaleMaskSoftmax(
input_in_fp16=input_in_fp16,
input_in_bf16=input_in_bf16,
......@@ -47,26 +55,40 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
mask.shape = [4, 1, 24, 24]
"""
for (dtype, scale, softmax_in_fp32) in itertools.product(
(torch.half, torch.bfloat16),
(None, 2.0),
(False, True),
(torch.half, torch.bfloat16), (None, 2.0), (False, True),
):
with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"):
input_in_fp16 = dtype == torch.half
input_in_bf16 = dtype == torch.bfloat16
if not (scale is None or softmax_in_fp32):
with self.assertRaises(RuntimeError):
self._setup_fused_softmax(input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.padding)
self._setup_fused_softmax(
input_in_fp16,
input_in_bf16,
scale,
softmax_in_fp32,
AttnMaskType.padding,
)
return
fused_fn, torch_fn = self._setup_fused_softmax(input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.padding)
attention_scores_0 = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype).requires_grad_(True)
fused_fn, torch_fn = self._setup_fused_softmax(
input_in_fp16,
input_in_bf16,
scale,
softmax_in_fp32,
AttnMaskType.padding,
)
attention_scores_0 = (
torch.randn((4, 12, 24, 24))
.to(device="cuda", dtype=dtype)
.requires_grad_(True)
)
with torch.no_grad():
attention_scores_1 = attention_scores_0.clone().requires_grad_(True)
mask = torch.randint(0, 2, (4, 1, 24, 24), device="cuda").bool()
expected = fused_fn(attention_scores_0, mask)
actual = torch_fn(attention_scores_1, mask)
torch.testing.assert_allclose(actual, expected)
torch.testing.assert_close(actual, expected)
g0 = torch.rand_like(actual)
with torch.no_grad():
......@@ -80,18 +102,23 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
input_in_fp16 = dtype == torch.half
input_in_bf16 = dtype == torch.bfloat16
fused_fn, torch_fn = self._setup_fused_softmax(
input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding)
input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding
)
attention_scores_0 = torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
attention_scores_0 = (
torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
)
with torch.no_grad():
attention_scores_1 = attention_scores_0.clone().to(dtype).requires_grad_(True)
attention_scores_1 = (
attention_scores_0.clone().to(dtype).requires_grad_(True)
)
mask = torch.randint(0, 2, (4, 1, 24, 24)).bool().cuda()
expected = torch_fn(attention_scores_1, mask)
with torch.cuda.amp.autocast(dtype=dtype):
actual = fused_fn(attention_scores_0, mask)
self.assertEqual(actual.dtype, dtype)
torch.testing.assert_allclose(actual, expected)
torch.testing.assert_close(actual, expected)
g0 = torch.rand_like(actual)
with torch.no_grad():
......@@ -108,9 +135,7 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
upper elements are True and lower elements and diagonal are False.
"""
for (dtype, scale, softmax_in_fp32) in itertools.product(
(torch.half, torch.bfloat16),
(None, 2.0),
(False, True),
(torch.half, torch.bfloat16), (None, 2.0), (False, True),
):
with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"):
input_in_fp16 = dtype == torch.half
......@@ -118,21 +143,37 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
if not (scale is None or softmax_in_fp32):
with self.assertRaises(RuntimeError):
self._setup_fused_softmax(
input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.causal)
input_in_fp16,
input_in_bf16,
scale,
softmax_in_fp32,
AttnMaskType.causal,
)
return
fused_fn, torch_fn = self._setup_fused_softmax(
input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.causal)
attn_weights_0 = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype).requires_grad_(True)
input_in_fp16,
input_in_bf16,
scale,
softmax_in_fp32,
AttnMaskType.causal,
)
attn_weights_0 = (
torch.randn((4, 12, 24, 24))
.to(device="cuda", dtype=dtype)
.requires_grad_(True)
)
with torch.no_grad():
attn_weights_1 = attn_weights_0.clone().requires_grad_(True)
total_mask = (~(
torch.tril(torch.randn((24, 24), device="cuda")).bool()
).unsqueeze(0).unsqueeze(0))
total_mask = (
~(torch.tril(torch.randn((24, 24), device="cuda")).bool())
.unsqueeze(0)
.unsqueeze(0)
)
total_mask = total_mask.repeat((4, 1, 1, 1))
expected = fused_fn(attn_weights_0, total_mask)
actual = torch_fn(attn_weights_1, total_mask)
torch.testing.assert_allclose(actual, expected)
torch.testing.assert_close(actual, expected)
g0 = torch.randn_like(actual)
with torch.no_grad():
......@@ -146,20 +187,27 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
input_in_fp16 = dtype == torch.half
input_in_bf16 = dtype == torch.bfloat16
fused_fn, torch_fn = self._setup_fused_softmax(
input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.causal)
input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.causal
)
attn_weights_0 = torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
attn_weights_0 = (
torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
)
with torch.no_grad():
attn_weights_1 = attn_weights_0.clone().to(dtype).requires_grad_(True)
total_mask = (~(
torch.tril(torch.randn((24, 24), device="cuda")).bool()
).unsqueeze(0).unsqueeze(0))
attn_weights_1 = (
attn_weights_0.clone().to(dtype).requires_grad_(True)
)
total_mask = (
~(torch.tril(torch.randn((24, 24), device="cuda")).bool())
.unsqueeze(0)
.unsqueeze(0)
)
with torch.cuda.amp.autocast(dtype=dtype):
actual = fused_fn(attn_weights_0, total_mask)
self.assertEqual(actual.dtype, dtype)
expected = torch_fn(attn_weights_1, total_mask)
torch.testing.assert_allclose(actual, expected)
torch.testing.assert_close(actual, expected)
g0 = torch.randn_like(actual)
with torch.no_grad():
......
import logging
import torch
import torch.nn as nn
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import layers
from apex.transformer.testing.commons import set_random_seed
from apex.transformer.testing.distributed_test_base import DistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
# N.B. (mkozuki): Disable TF32 matrix multiply.
# The matrices used in this test are so small that TF32 matmul
# can be imprecise enough for `self.assertEqual` to fail.
torch.backends.cuda.matmul.allow_tf32 = False
class TensorParallelLayerTest(DistributedTestBase):
BATCH_SIZE: int = 17
SEQUENCE_LENGTH: int = 23
VOCAB_SIZE: int = 48
HIDDEN_SIZE: int = 16
INPUT_SIZE_COEFF: int = 13
OUTPUT_SIZE_COEFF: int = 17
SEED: int = 123
def test_parallel_embedding(self) -> None:
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
set_random_seed(TensorParallelLayerTest.SEED + 1)
input_tensor = torch.randint(
0,
TensorParallelLayerTest.VOCAB_SIZE,
(
TensorParallelLayerTest.BATCH_SIZE,
TensorParallelLayerTest.SEQUENCE_LENGTH,
),
device="cuda",
)
loss_weight = torch.randn(
(
TensorParallelLayerTest.BATCH_SIZE,
TensorParallelLayerTest.SEQUENCE_LENGTH,
TensorParallelLayerTest.HIDDEN_SIZE,
),
device="cuda",
)
set_random_seed(TensorParallelLayerTest.SEED)
embedding_torch = nn.Embedding(
TensorParallelLayerTest.VOCAB_SIZE,
TensorParallelLayerTest.HIDDEN_SIZE,
).cuda()
output_torch = embedding_torch(input_tensor)
loss_torch = torch.mul(output_torch, loss_weight).sum()
loss_torch.backward()
# N.B. (mkozuki): With affine weight initialization on GPU,
# it's super difficult to keep the consistency with nn.Embedding.
# Thus, turning on `use_cpu_initialization`.
set_random_seed(TensorParallelLayerTest.SEED)
embedding_vocab_parallel = layers.VocabParallelEmbedding(
TensorParallelLayerTest.VOCAB_SIZE,
TensorParallelLayerTest.HIDDEN_SIZE,
init_method=nn.init.normal_,
use_cpu_initialization=True,
).cuda()
output_vocab_parallel = embedding_vocab_parallel(input_tensor)
loss_vocab_parallel = torch.mul(
output_vocab_parallel, loss_weight
).sum()
loss_vocab_parallel.backward()
self.assertEqual(output_torch, output_vocab_parallel)
self.assertEqual(loss_torch, loss_vocab_parallel)
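                # VocabParallelEmbedding shards the table along the vocab dimension,
                # so only this rank's slice of the reference gradient is compared.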
splitted_weight_torch = torch.split(
embedding_torch.weight.grad,
TensorParallelLayerTest.VOCAB_SIZE
// tensor_model_parallel_world_size,
0,
)[parallel_state.get_tensor_model_parallel_rank()]
self.assertEqual(
splitted_weight_torch, embedding_vocab_parallel.weight.grad
)
parallel_state.destroy_model_parallel()
def _affine_weight_init_test_impl(
self, init_device: str, is_column_parallel: bool
) -> None:
dim = int(not is_column_parallel)
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size
)
input_size: int = TensorParallelLayerTest.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
output_size: int = TensorParallelLayerTest.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size
weight_shape = (
(TensorParallelLayerTest.OUTPUT_SIZE_COEFF, input_size)
if is_column_parallel
else (output_size, TensorParallelLayerTest.INPUT_SIZE_COEFF)
)
weight = torch.empty(weight_shape)
set_random_seed(TensorParallelLayerTest.SEED)
sharding_dim_size = (
TensorParallelLayerTest.OUTPUT_SIZE_COEFF
if is_column_parallel
else TensorParallelLayerTest.INPUT_SIZE_COEFF
)
if init_device == "cpu":
layers._initialize_affine_weight_cpu(
weight,
output_size,
input_size,
sharding_dim_size,
dim,
nn.init.normal_,
params_dtype=torch.float32,
)
else:
layers._initialize_affine_weight_gpu(
weight, torch.nn.init.normal_, dim
)
# Target
set_random_seed(TensorParallelLayerTest.SEED)
if init_device == "cpu":
main_weight = torch.empty(output_size, input_size)
nn.init.normal_(main_weight)
curr_weight = torch.split(main_weight, sharding_dim_size, dim=dim)[
parallel_state.get_tensor_model_parallel_rank()
]
else:
curr_weight = torch.empty(*weight_shape)
nn.init.normal_(curr_weight)
self.assertEqual(curr_weight, weight)
parallel_state.destroy_model_parallel()
def test_affine_weight_init_column_parallel_cpu(self) -> None:
self._affine_weight_init_test_impl(init_device="cpu", is_column_parallel=True)
def test_affine_weight_init_column_parallel_gpu(self) -> None:
self._affine_weight_init_test_impl(init_device="gpu", is_column_parallel=True)
def test_affine_weight_init_row_parallel_cpu(self) -> None:
self._affine_weight_init_test_impl(init_device="cpu", is_column_parallel=False)
def test_affine_weight_init_row_parallel_gpu(self) -> None:
self._affine_weight_init_test_impl(init_device="gpu", is_column_parallel=False)
def test_row_parallel_linear(self) -> None:
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size
)
input_size: int = TensorParallelLayerTest.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
output_size: int = TensorParallelLayerTest.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size
set_random_seed(TensorParallelLayerTest.SEED)
linear_layer = layers.RowParallelLinear(
input_size,
output_size,
keep_master_weight_for_test=True,
params_dtype=torch.float32,
use_cpu_initialization=True,
).cuda()
loss_weight = torch.randn(
(TensorParallelLayerTest.BATCH_SIZE, output_size)
).cuda()
# Forward and backward
input_tensor = torch.randn(
TensorParallelLayerTest.BATCH_SIZE, input_size, requires_grad=True
).cuda()
input_tensor.retain_grad()
output, _ = linear_layer(input_tensor)
loss = torch.mul(output, loss_weight).sum()
loss.backward()
self.assertIsNotNone(input_tensor.grad)
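                # Analytic gradients for y = x A^T + b:
                #   dL/dA = dL/dy^T x, dL/db = sum of dL/dy over the batch, dL/dx = dL/dy A.
                # RowParallelLinear shards A along the input dimension, so only this rank's
                # column shard of dL/dA is checked below.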
with torch.no_grad():
dldy = loss_weight.clone()
x = input_tensor.clone()
a = linear_layer.master_weight.cuda()
dlda = torch.matmul(dldy.t(), x)
dldb = torch.matmul(
torch.ones(TensorParallelLayerTest.BATCH_SIZE, 1).cuda().t(), dldy
).view(-1)
dldx = torch.matmul(dldy, a)
with torch.no_grad():
curr_dlda = torch.split(
dlda, TensorParallelLayerTest.INPUT_SIZE_COEFF, dim=1
)[parallel_state.get_tensor_model_parallel_rank()].clone()
self.assertEqual(linear_layer.weight.grad, curr_dlda)
self.assertEqual(input_tensor.grad, dldx)
self.assertEqual(linear_layer.bias.grad, dldb)
parallel_state.destroy_model_parallel()
def test_column_parallel_linear(self):
self._column_parallel_linear_test_impl(False, False)
def test_column_parallel_linear_no_async(self):
self._column_parallel_linear_test_impl(True, False)
def test_column_parallel_linear_gradient_accumulation_fusion(self):
self._column_parallel_linear_test_impl(False, True)
def _column_parallel_linear_test_impl(
self,
no_async_tensor_model_parallel_allreduce: bool,
gradient_accumulation_fusion: bool,
):
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
print(
f"tensor_model_parallel_world_size={tensor_model_parallel_world_size}"
)
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
if self.world_size % tensor_model_parallel_world_size:
continue
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
feature_size_coeff = TensorParallelLayerTest.INPUT_SIZE_COEFF
feature_size = feature_size_coeff * tensor_model_parallel_world_size
hidden_size = feature_size
set_random_seed(TensorParallelLayerTest.SEED)
input_tensor = torch.randn(
TensorParallelLayerTest.BATCH_SIZE,
hidden_size,
feature_size,
device="cuda",
requires_grad=True,
)
input_tensor.retain_grad()
loss_weight = torch.randn(
(TensorParallelLayerTest.BATCH_SIZE, hidden_size, feature_size,),
device="cuda",
)
linear = layers.ColumnParallelLinear(
feature_size,
feature_size,
bias=False,
keep_master_weight_for_test=True,
params_dtype=torch.float32,
use_cpu_initialization=True,
no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce,
gradient_accumulation_fusion=gradient_accumulation_fusion,
).cuda()
if gradient_accumulation_fusion:
with torch.no_grad():
linear.weight.main_grad = torch.randn_like(linear.weight)
output, _ = linear(input_tensor)
self.assertEqual(
output.shape,
(TensorParallelLayerTest.BATCH_SIZE, hidden_size, feature_size,),
)
loss = torch.mul(output, loss_weight).sum()
loss.backward()
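                # Analytic gradients for y = x A^T (no bias) with a batched 3D input:
                #   dL/dx = dL/dy A and dL/dA = sum over the batch of dL/dy^T x.
                # ColumnParallelLinear shards A along the output dimension, so only this
                # rank's row shard of dL/dA is compared (and only in the simple case below).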
with torch.no_grad():
dldy = loss_weight.clone()
x = input_tensor.clone()
a = linear.master_weight.cuda().clone()
dldx = torch.matmul(dldy, a)
self.assertEqual(input_tensor.grad, dldx)
# TODO (mkozuki): Cover the other cases.
if (
tensor_model_parallel_world_size == 1
and not gradient_accumulation_fusion
):
dlda = torch.matmul(torch.transpose(dldy, 1, 2), x).sum(dim=0)
curr_dlda = torch.split(dlda, feature_size_coeff, dim=0)[
parallel_state.get_tensor_model_parallel_rank()
]
self.assertEqual(linear.weight.grad, curr_dlda)
parallel_state.destroy_model_parallel()
if __name__ == "__main__":
common_utils.run_tests()
import logging
import torch
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import mappings
from apex.transformer.testing.distributed_test_base import DistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
class MappingTest(DistributedTestBase):
def test_reduce(self):
for tensor_model_paralell_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_paralell_world_size > 0:
continue
with self.subTest(
tensor_model_paralell_world_size=tensor_model_paralell_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_paralell_world_size
)
t = torch.full((10, 10, 10, 10), 50, device=f"cuda:{self.rank}")
expected = torch.full(
(10, 10, 10, 10),
50 * tensor_model_paralell_world_size,
device=f"cuda:{self.rank}",
)
self.assertTrue(torch.equal(mappings._reduce(t), expected))
parallel_state.destroy_model_parallel()
def test_split(self):
for tensor_model_paralell_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_paralell_world_size > 0:
continue
with self.subTest(
tensor_model_paralell_world_size=tensor_model_paralell_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_paralell_world_size
)
tensors = [
torch.randn(10, 1)
for rank in range(tensor_model_paralell_world_size)
]
x = torch.cat(tensors, 1)
out = mappings._split(x)
self.assertTrue(
torch.equal(
out, tensors[parallel_state.get_tensor_model_parallel_rank()]
)
)
parallel_state.destroy_model_parallel()
def test_gather(self):
for tensor_model_paralell_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_paralell_world_size > 0:
continue
with self.subTest(
tensor_model_paralell_world_size=tensor_model_paralell_world_size
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_paralell_world_size
)
device = f"cuda:{self.rank}"
gathered = mappings._gather(
torch.tensor(
[parallel_state.get_tensor_model_parallel_rank()], device=device
)
)
expected = torch.tensor(
[rank for rank in range(tensor_model_paralell_world_size)],
device=device,
)
self.assertTrue(torch.equal(gathered, expected))
parallel_state.destroy_model_parallel()
if __name__ == "__main__":
common_utils.run_tests()
import logging
import os
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.testing.distributed_test_base import DistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
os.environ["BACKEND"] = "NCCL"
DATA_PARALLEL_WORLD_SIZE: int = 1
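# Ranks are assigned to tensor-model-parallel groups contiguously, so a rank's position
# inside its group is simply the global rank modulo the group size.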
def calc_expected_tensor_model_parallel_rank(
rank: int, tensor_model_parallel_world_size: int,
) -> int:
return rank % tensor_model_parallel_world_size
class ParallelStateTest(DistributedTestBase):
def test_initialize_model_parallel(self) -> None:
self.assertFalse(parallel_state.model_parallel_is_initialized())
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
if self.world_size % tensor_model_parallel_world_size:
continue
pipeline_model_parallel_world_size = (
self.world_size // tensor_model_parallel_world_size
)
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
)
self.assertEqual(
tensor_model_parallel_world_size,
parallel_state.get_tensor_model_parallel_world_size(),
)
                expected_tensor_model_parallel_rank = calc_expected_tensor_model_parallel_rank(
self.rank, tensor_model_parallel_world_size
)
self.assertEqual(
expected_tensor_model_parallel_rank,
parallel_state.get_tensor_model_parallel_rank(),
)
expected_tensor_model_parallel_src_rank = (
self.rank // tensor_model_parallel_world_size
) * tensor_model_parallel_world_size
self.assertEqual(
expected_tensor_model_parallel_src_rank,
parallel_state.get_tensor_model_parallel_src_rank(),
)
parallel_state.destroy_model_parallel()
self.assertFalse(parallel_state.model_parallel_is_initialized())
def test_initialize_model_parallel_with_virtual_and_split(self) -> None:
if self.world_size < 4:
self.skipTest("requires >= 4 GPUs")
self.assertFalse(parallel_state.model_parallel_is_initialized())
tensor_model_parallel_world_size = 1 + int(self.world_size > 4)
pipeline_model_parallel_world_size = (
self.world_size // tensor_model_parallel_world_size
)
virtual_pipeline_model_parallel_world_size = 2
pipeline_model_parallel_split_rank = pipeline_model_parallel_world_size // 2
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_world_size,
pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
)
self.assertEqual(
            calc_expected_tensor_model_parallel_rank(
self.rank, tensor_model_parallel_world_size
),
parallel_state.get_tensor_model_parallel_rank(),
)
self.assertEqual(
pipeline_model_parallel_world_size,
parallel_state.get_pipeline_model_parallel_world_size(),
)
self.assertEqual(
virtual_pipeline_model_parallel_world_size,
parallel_state.get_virtual_pipeline_model_parallel_world_size(),
)
expected_pipeline_rank = (
self.rank - (self.rank % tensor_model_parallel_world_size)
) % pipeline_model_parallel_world_size
self.assertEqual(
expected_pipeline_rank, parallel_state.get_pipeline_model_parallel_rank(),
)
# virtual pipeline model parallel rank is lazily set, i.e., right after the call of
# `initialize_model_parallel`, it's set to 0.
self.assertEqual(
0, parallel_state.get_virtual_pipeline_model_parallel_rank(),
)
self.assertEqual(
pipeline_model_parallel_split_rank,
parallel_state.get_pipeline_model_parallel_split_rank(),
)
fake_split_rank = 77
parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank)
self.assertEqual(
fake_split_rank, parallel_state.get_pipeline_model_parallel_split_rank()
)
parallel_state.destroy_model_parallel()
if __name__ == "__main__":
common_utils.run_tests()
import logging
import itertools
from typing import Optional
import torch
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex._autocast_utils import _get_autocast_dtypes
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import utils as pp_utils
from apex.transformer.pipeline_parallel.schedules.common import (
FwdStepFunc,
build_model,
_get_params_for_weight_decay_optimization,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import (
forward_backward_no_pipelining,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import (
_forward_backward_pipelining_with_interleaving,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
forward_backward_pipelining_without_interleaving,
)
from apex.transformer.testing.distributed_test_base import DistributedTestBase
from apex.transformer.testing import commons as testing_utils
logging.getLogger("apex").setLevel(logging.WARNING)
class PipelineParallelForwardBackwardTest(DistributedTestBase):
GLOBAL_BATCH_SIZE = 16
MICRO_BATCH_SIZE = 1
HIDDEN_SIZE = 32
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 8)
def _forward_backward_test_impl(
self,
forward_only: bool,
fwd_bwd_func: FwdStepFunc,
pipeline_model_parallel_world_size: Optional[int],
        virtual_pipeline_model_parallel_size: Optional[int],
) -> None:
for dtype, deallocate_pipeline_outputs in itertools.product(
[torch.float32] + _get_autocast_dtypes(), (True, False),
):
grad_scaler = (
torch.cuda.amp.GradScaler(init_scale=4.0)
if dtype == torch.half
else None
)
tensor_model_parallel_world_size = 1
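            # Use 2-way data parallelism only when at least 8 (even) ranks are available;
            # the remaining ranks make up the pipeline dimension.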
data_parallel_size = 1 + (self.world_size >= 8 and self.world_size % 2 == 0)
            pipeline_model_parallel_world_size = (
                self.world_size
                // (tensor_model_parallel_world_size * data_parallel_size)
                if pipeline_model_parallel_world_size is None
                else pipeline_model_parallel_world_size
            )
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
                virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
)
pp_utils._reconfigure_microbatch_calculator(
rank=parallel_state.get_tensor_model_parallel_rank(),
rampup_batch_size=None,
global_batch_size=PipelineParallelForwardBackwardTest.GLOBAL_BATCH_SIZE,
micro_batch_size=PipelineParallelForwardBackwardTest.MICRO_BATCH_SIZE,
data_parallel_size=parallel_state.get_data_parallel_world_size(),
)
global_batch_shape = (
PipelineParallelForwardBackwardTest.GLOBAL_BATCH_SIZE
// parallel_state.get_data_parallel_world_size(),
PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
)
batch = (torch.randn(global_batch_shape).cuda(),)
model = build_model(
testing_utils.model_provider_func,
wrap_with_ddp=True,
                virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
hidden_size=PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
)
_param_groups = _get_params_for_weight_decay_optimization(model)
optimizer = torch.optim.Adam(_param_groups, lr=1e-3)
pp_utils.update_num_microbatches(0)
fwd_bwd_func(
testing_utils.fwd_step_func,
batch,
model,
forward_only=forward_only,
# `tensor_shape` is the shape of micro batch.
tensor_shape=(
PipelineParallelForwardBackwardTest.MICRO_BATCH_SIZE,
PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
),
dtype=dtype,
grad_scaler=grad_scaler,
                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
)
if not forward_only:
for m in model:
for p in m.parameters():
self.assertIsNotNone(p.grad)
optimizer.step()
optimizer.zero_grad(set_to_none=True)
parallel_state.destroy_model_parallel()
def test_no_pipelining(self):
self._forward_backward_test_impl(False, forward_backward_no_pipelining, 1, None)
def test_no_pipelining_inference(self):
self._forward_backward_test_impl(True, forward_backward_no_pipelining, 1, None)
def test_pipelining(self):
self._forward_backward_test_impl(
False, forward_backward_pipelining_without_interleaving, None, None
)
def test_pipelining_inference(self):
self._forward_backward_test_impl(
True, forward_backward_pipelining_without_interleaving, None, None
)
def test_pipelining_with_interleaving(self):
self._forward_backward_test_impl(
False, _forward_backward_pipelining_with_interleaving, 2, None
)
def test_pipelining_with_interleaving_inference(self):
self._forward_backward_test_impl(
True, _forward_backward_pipelining_with_interleaving, 2, None
)
if __name__ == "__main__":
common_utils.run_tests()