Unverified Commit 365fdc18 authored by Masaki Kozuki's avatar Masaki Kozuki Committed by GitHub
Browse files

transformer utils (#1181)


Co-authored-by: default avatarPiotr Bialecki <pbialecki@nvidia.com>
Co-authored-by: default avatarEddie Yan <eddiey@nvidia.com>
Co-authored-by: default avatarRishi Puri <riship@nvidia.com>
Co-authored-by: default avatarSangkug Lym <slym@nvidia.com>
parent bdac244e
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.init as init
from torch.nn.parameter import Parameter
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import layers
from apex.transformer.tensor_parallel.tests import global_vars
from apex.transformer.tensor_parallel.tests.commons import set_random_seed
from apex.transformer.tensor_parallel.tests.commons import print_separator
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
def test_parallel_embedding(tensor_model_parallel_size):
    """Check the parallel embedding layers against ``torch.nn.Embedding``.

    A reference ``torch.nn.Embedding``, a ``layers.ParallelEmbedding`` and a
    ``layers.VocabParallelEmbedding`` are each built right after
    ``set_random_seed(seed)``, fed the same token ids and reduced to a scalar
    loss with the same weighting tensor.  The losses must match the reference,
    and each parallel layer's weight gradient must match this rank's chunk of
    the reference gradient (split along dim 1 for ``ParallelEmbedding`` and
    along dim 0 for ``VocabParallelEmbedding``).

    NOTE(review): the ``__main__`` driver of this file has the call to this
    test commented out.

    Args:
        tensor_model_parallel_size: value handed to
            ``parallel_state.initialize_model_parallel``.

    Raises:
        AssertionError: if any loss or gradient error is not below ``1.0e-12``.
    """
    def _report_and_check(template, err):
        # Every rank prints its own error before the tolerance is enforced.
        print(template.format(torch.distributed.get_rank(), err))
        assert err < 1.0e-12, 'error: {}'.format(err)

    if torch.distributed.get_rank() == 0:
        print('> testing parallel embedding with model parallel size {} ...'.
              format(tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    # From here on use the size reported by the initialized groups.
    tp_world_size = parallel_state.get_tensor_model_parallel_world_size()

    batch_size, seq_length = 17, 23
    vocab_size, hidden_size = 48, 16
    seed = 1236

    # The inputs are drawn under seed 123; the layers are built under ``seed``.
    set_random_seed(123)
    token_ids = torch.LongTensor(
        size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
    loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

    def _loss_with_backward(embedding):
        # Forward, weighted sum, backward -- identical for all three layers.
        loss = torch.mul(embedding(token_ids), loss_weight).sum()
        loss.backward()
        return loss

    # Reference layer.
    set_random_seed(seed)
    embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()
    loss_original = _loss_with_backward(embedding_original)

    # Parallel embedding.
    set_random_seed(seed)
    embedding_parallel = layers.ParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    loss_parallel = _loss_with_backward(embedding_parallel)

    # Vocab-parallel embedding.
    set_random_seed(seed)
    embedding_vocab_parallel = layers.VocabParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    loss_vocab_parallel = _loss_with_backward(embedding_vocab_parallel)

    # Losses.
    torch.distributed.barrier()
    _report_and_check(' error in loss (parallel) on global rank {}: {}',
                      loss_parallel.sub(loss_original).abs())

    torch.distributed.barrier()
    _report_and_check(' error in loss (vocab parallel) on global rank {}: {}',
                      loss_vocab_parallel.sub(loss_original).abs())

    # Gradients: compare against this rank's chunk of the reference gradient.
    hidden_chunk = torch.split(embedding_original.weight.grad,
                               hidden_size // tp_world_size,
                               1)[parallel_state.get_tensor_model_parallel_rank()]
    _report_and_check(' error in grad (parallel) on global rank {}: {}',
                      embedding_parallel.weight.grad.sub(hidden_chunk).abs().max())

    vocab_chunk = torch.split(embedding_original.weight.grad,
                              vocab_size // tp_world_size,
                              0)[parallel_state.get_tensor_model_parallel_rank()]
    _report_and_check(' error in grad (vocab parallel) on global rank {}: {}',
                      embedding_vocab_parallel.weight.grad.sub(
                          vocab_chunk).abs().max())

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
def test_initialize_affine_weight(tensor_model_parallel_size, device):
    """Check that affine-weight initialization yields this rank's shard of a master weight.

    For both partition dimensions (0 = "column parallel", 1 = "row parallel")
    a shard is initialized with ``torch.nn.init.normal_`` right after
    ``set_random_seed(seed)``.  It is compared with the rank-th chunk of a
    full ``(output_size, input_size)`` weight initialized the same way under
    the same seed.

    Args:
        tensor_model_parallel_size: value handed to
            ``parallel_state.initialize_model_parallel``.
        device: ``'cpu'`` selects ``layers._initialize_affine_weight_cpu``;
            any other value selects ``layers._initialize_affine_weight_gpu``.

    Raises:
        AssertionError: if the max absolute difference is not below ``1.0e-6``.
    """
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing initialize_affine_weight with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tp_world_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    input_size_coeff = 13
    input_size = input_size_coeff * tp_world_size
    output_size_coeff = 17
    output_size = output_size_coeff * tp_world_size

    def _initialize_shard(shard, per_partition_size, partition_dim):
        # Seed first so the initializer sees the same RNG stream as the target.
        set_random_seed(seed)
        if device == 'cpu':
            layers._initialize_affine_weight_cpu(
                shard, output_size, input_size, per_partition_size,
                partition_dim, torch.nn.init.normal_,
                params_dtype=global_vars.get_args().params_dtype,
            )
        else:
            layers._initialize_affine_weight_gpu(
                shard, torch.nn.init.normal_, partition_dim)

    def _target_shard(per_partition_size, partition_dim):
        # Build the full weight under the same seed and cut out this rank's chunk.
        set_random_seed(seed)
        master_weight = torch.empty(output_size, input_size)
        torch.nn.init.normal_(master_weight)
        rank = parallel_state.get_tensor_model_parallel_rank()
        return torch.split(master_weight, per_partition_size,
                           dim=partition_dim)[rank].contiguous().clone()

    # ---------------
    # Column parallel
    # ---------------
    weight = torch.empty(output_size_coeff, input_size)
    _initialize_shard(weight, output_size_coeff, 0)
    my_weight = _target_shard(output_size_coeff, 0)

    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(' column parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # ------------
    # Row parallel
    # ------------
    weight = torch.empty(output_size, input_size_coeff)
    _initialize_shard(weight, input_size_coeff, 1)
    my_weight = _target_shard(input_size_coeff, 1)

    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(' row parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
class IdentityLayer2D(torch.nn.Module):
    """Module holding a single ``(m, n)`` parameter that ``forward`` returns unchanged.

    The tests use it as a differentiable "input": after ``backward`` the
    gradient with respect to the input is available as ``weight.grad``.
    """

    def __init__(self, m, n):
        """Create the ``(m, n)`` weight and fill it with ``xavier_normal_``."""
        super().__init__()
        # xavier_normal_ fills the tensor in place and returns that same tensor.
        initial = torch.nn.init.xavier_normal_(torch.empty(m, n))
        self.weight = Parameter(initial)

    def forward(self):
        """Return the weight parameter itself (no inputs, no copy)."""
        return self.weight
def test_column_parallel_linear(tensor_model_parallel_size):
    """Check ``ColumnParallelLinear`` gradients against hand-computed values.

    The loss is ``sum(output * loss_weight)``, so ``dL/dY`` equals
    ``loss_weight``.  From it the test derives the full-size ``dL/dA``,
    ``dL/db`` and ``dL/dX`` using the layer's ``master_weight``.  It then
    compares them with the layer's gradients: the weight and bias gradients
    against this rank's chunk along dim 0, the input gradient in full.

    Args:
        tensor_model_parallel_size: value handed to
            ``parallel_state.initialize_model_parallel``.

    Raises:
        AssertionError: if any max absolute error is not below ``1.0e-6``.
    """
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing ColumnParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tp_world_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tp_world_size
    output_size_coeff = 17
    output_size = output_size_coeff * tp_world_size
    batch_size = 7

    def _compare(template, expected, actual):
        # Same sequence for every quantity: error, barrier, report, assert.
        error = expected.sub(actual).abs().max()
        torch.distributed.barrier()
        print(template.format(torch.distributed.get_rank(), error))
        assert error < 1.0e-6

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = layers.ColumnParallelLinear(
        input_size, output_size, keep_master_weight_for_test=True,
        params_dtype=global_vars.get_args().params_dtype,
        use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
    ).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()

    # Forward (the layer returns a pair; only the first element is used).
    input_ = identity_layer()
    output, _ = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Expected full-size gradients.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = parallel_state.get_tensor_model_parallel_rank()

    my_dLdA = torch.split(dLdA, output_size_coeff,
                          dim=0)[rank].contiguous().clone()
    _compare(' error in dLdA on global rank {}: {}',
             my_dLdA, linear_layer.weight.grad)

    my_dLdb = torch.split(dLdb, output_size_coeff,
                          dim=0)[rank].contiguous().clone()
    _compare(' error in dLdb on global rank {}: {}',
             my_dLdb, linear_layer.bias.grad)

    _compare(' error in dLdX on global rank {}: {}',
             dLdX, identity_layer.weight.grad)

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
def test_row_parallel_linear(tensor_model_parallel_size):
    """Check ``RowParallelLinear`` gradients against hand-computed values.

    The loss is ``sum(output * loss_weight)``, so ``dL/dY`` equals
    ``loss_weight``.  The expected weight gradient is this rank's chunk of
    the full ``dL/dA`` along dim 1.  The bias gradient and the input gradient
    are compared in full.

    Args:
        tensor_model_parallel_size: value handed to
            ``parallel_state.initialize_model_parallel``.

    Raises:
        AssertionError: if any max absolute error is not below ``1.0e-6``.
    """
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing RowParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tp_world_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tp_world_size
    output_size_coeff = 17
    output_size = output_size_coeff * tp_world_size
    batch_size = 7

    def _compare(template, expected, actual):
        # Same sequence for every quantity: error, barrier, report, assert.
        error = expected.sub(actual).abs().max()
        torch.distributed.barrier()
        print(template.format(torch.distributed.get_rank(), error))
        assert error < 1.0e-6

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = layers.RowParallelLinear(
        input_size, output_size, keep_master_weight_for_test=True,
        params_dtype=global_vars.get_args().params_dtype,
        use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
    ).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()

    # Forward (the layer returns a pair; only the first element is used).
    input_ = identity_layer()
    output, _ = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Expected full-size gradients.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = parallel_state.get_tensor_model_parallel_rank()

    my_dLdA = torch.split(dLdA, input_size_coeff,
                          dim=1)[rank].contiguous().clone()
    _compare(' error in dLdA on global rank {}: {}',
             my_dLdA, linear_layer.weight.grad)

    # The bias gradient is compared without splitting.
    _compare(' error in dLdb on global rank {}: {}',
             dLdb, linear_layer.bias.grad)

    _compare(' error in dLdX on global rank {}: {}',
             dLdX, identity_layer.weight.grad)

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
class IdentityLayer3D(torch.nn.Module):
    """Module holding a single ``(m, n, k)`` parameter that ``forward`` returns unchanged.

    The tests use it as a differentiable "input": after ``backward`` the
    gradient with respect to the input is available as ``weight.grad``.
    """

    def __init__(self, m, n, k):
        """Create the ``(m, n, k)`` weight and fill it with ``xavier_normal_``."""
        super().__init__()
        # xavier_normal_ fills the tensor in place and returns that same tensor.
        initial = torch.nn.init.xavier_normal_(torch.empty(m, n, k))
        self.weight = Parameter(initial)

    def forward(self):
        """Return the weight parameter itself (no inputs, no copy)."""
        return self.weight
def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition,
                            hidden_size_per_att_head, dropout_prob, batch_size,
                            sequence_length):
    """Run one forward/backward pass through a parallel self-attention layer.

    The model-parallel groups are created at the start and destroyed before
    returning.  The head count is ``num_att_heads_per_partition`` times the
    global ``torch.distributed`` world size, not the tensor-parallel size, so
    it is the same for every ``tensor_model_parallel_size``.

    NOTE(review): the layer is looked up as
    ``parallel_state.BertParallelSelfAttention``, and the ``__main__`` driver
    of this file has the self-attention test commented out as "Deleted".
    Confirm the attribute still exists before re-enabling.

    Returns:
        Tuple ``(rank, hidden_size, tensor_model_parallel_size, loss,
        attention_layer, identity_layer)``.  ``rank`` is this process's
        tensor-model-parallel rank, and the size is the one reported by
        ``parallel_state`` after initialization.
    """
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tp_world_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    attention_layer = parallel_state.BertParallelSelfAttention(hidden_size, num_att_heads,
                                                               dropout_prob).cuda()
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()

    # Forward
    input_ = identity_layer()
    output = attention_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Read the rank before the groups are torn down.
    rank = parallel_state.get_tensor_model_parallel_rank()
    parallel_state.destroy_model_parallel()
    return (rank, hidden_size, tp_world_size, loss,
            attention_layer, identity_layer)
def test_parallel_self_attention(tensor_model_parallel_size):
    """Compare self-attention at tensor-parallel size 1 with the requested size.

    ``parallel_self_attention`` is run twice with identical hyper-parameters,
    once with size 1 and once with ``tensor_model_parallel_size``.  The loss,
    the ``query_key_value`` weight gradient and the input gradient must agree
    within ``5.0e-6``.

    Args:
        tensor_model_parallel_size: tensor-model-parallel size of the second run.

    Raises:
        AssertionError: if the hidden sizes differ or any error reaches the
            tolerance.
    """
    if torch.distributed.get_rank() == 0:
        print('> testing ParallelSelfAttention with model parallel '
              'size: {}'.format(tensor_model_parallel_size))

    num_att_heads_per_partition = 3
    hidden_size_per_att_head = 7
    dropout_prob = 0.0  # has to be zero
    batch_size = 5
    sequence_length = 13

    # Baseline: tensor-model-parallel size 1.
    (_rank_1, hidden_size_1, _tp_size_1, loss_1,
     attention_layer_1, identity_layer_1) = parallel_self_attention(
        1, num_att_heads_per_partition,
        hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)

    # Run under test.
    (rank, hidden_size, tp_world_size, loss,
     attention_layer, identity_layer) = parallel_self_attention(
        tensor_model_parallel_size, num_att_heads_per_partition,
        hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
    assert hidden_size_1 == hidden_size

    def _compare(template, error):
        # Same sequence for every quantity: barrier, report, assert.
        torch.distributed.barrier()
        print(template.format(torch.distributed.get_rank(), error))
        assert error < 5.0e-6

    _compare(' loss error on global rank {}: {}',
             loss_1.sub(loss).abs().max())

    # Cut the baseline gradient into chunks of hidden_size // tp size along
    # dim 0, keep every tp-size-th chunk starting at ``rank``, and concatenate.
    my_lin_grad_list = torch.split(
        attention_layer_1.query_key_value.weight.grad,
        hidden_size // tp_world_size, 0)[rank::tp_world_size]
    my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
    _compare(' weight gradient error on global rank {}: {}',
             my_lin_grad.sub(
                 attention_layer.query_key_value.weight.grad).abs().max())

    _compare(' input gradient error on global rank {}: {}',
             identity_layer_1.weight.grad.sub(
                 identity_layer.weight.grad).abs().max())

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
                         hidden_size_per_att_head, batch_size, sequence_length):
    """Run one forward/backward pass through a parallel transformer layer.

    The model-parallel groups are created at the start and destroyed before
    returning.  The head count is ``num_att_heads_per_partition`` times the
    global ``torch.distributed`` world size, and the intermediate size is four
    times the hidden size.

    NOTE(review): the layer is looked up as
    ``parallel_state.BertParallelTransformerLayer``, and the ``__main__``
    driver of this file has the transformer test commented out because the
    layer "no longer exists".

    Returns:
        Tuple ``(rank, hidden_size, tensor_model_parallel_size, loss,
        transformer_layer, identity_layer)``.  ``rank`` is this process's
        tensor-model-parallel rank, and the size is the one reported by
        ``parallel_state`` after initialization.
    """
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tp_world_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads
    intermediate_size = 4 * hidden_size

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    transformer_layer = parallel_state.BertParallelTransformerLayer(
        hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
        torch.nn.functional.relu, 1.0e-5).cuda()
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()

    # Forward
    input_ = identity_layer()
    output = transformer_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Read the rank before the groups are torn down.
    rank = parallel_state.get_tensor_model_parallel_rank()
    parallel_state.destroy_model_parallel()
    return (rank, hidden_size, tp_world_size, loss,
            transformer_layer, identity_layer)
def test_parallel_transformer_layer(tensor_model_parallel_size):
    """Compare a transformer layer at tensor-parallel size 1 with the requested size.

    ``parallel_transformer`` is run twice with identical hyper-parameters,
    once with size 1 and once with ``tensor_model_parallel_size``.  The loss
    and the input gradient must agree within ``5.0e-5``.

    Args:
        tensor_model_parallel_size: tensor-model-parallel size of the second run.

    Raises:
        AssertionError: if either error reaches the tolerance.
    """
    if torch.distributed.get_rank() == 0:
        print('> testing ParallelTransformerLayer with model parallel '
              'size: {}'.format(tensor_model_parallel_size))

    num_att_heads_per_partition = 3
    hidden_size_per_att_head = 7
    batch_size = 5
    sequence_length = 13

    # Baseline: tensor-model-parallel size 1.
    (_rank_1, _hidden_size_1, _tp_size_1, loss_1,
     _transformer_layer_1, identity_layer_1) = parallel_transformer(
        1, num_att_heads_per_partition,
        hidden_size_per_att_head, batch_size, sequence_length)

    # Run under test.
    (_rank, _hidden_size, _tp_size, loss,
     _transformer_layer, identity_layer) = parallel_transformer(
        tensor_model_parallel_size, num_att_heads_per_partition,
        hidden_size_per_att_head, batch_size, sequence_length)

    def _compare(template, error):
        # Same sequence for every quantity: barrier, report, assert.
        torch.distributed.barrier()
        print(template.format(torch.distributed.get_rank(), error))
        assert error < 5.0e-5, 'error: {}'.format(error)

    _compare(' loss error on global rank {}: {}',
             loss_1.sub(loss).abs().max())

    _compare(' input gradient error on global rank {}: {}',
             identity_layer_1.weight.grad.sub(
                 identity_layer.weight.grad).abs().max())

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
if __name__ == '__main__':
    # Driver for the layer tests.  Every enabled suite is run at
    # tensor-model-parallel sizes 1, 2, 4, ... up to the world size.

    # Request deterministic cuDNN behaviour and switch off its benchmark mode.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    print_separator('test initialize affine weight cpu')
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_initialize_affine_weight(tensor_model_parallel_size, 'cpu')
        tensor_model_parallel_size *= 2

    # Reset groups
    parallel_state.destroy_model_parallel()

    print_separator('test initialize affine weight gpu')
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_initialize_affine_weight(tensor_model_parallel_size, 'gpu')
        tensor_model_parallel_size *= 2

    # Deleted, replaced with vocab parallel embedding?
    #tensor_model_parallel_size = 1
    #while tensor_model_parallel_size <= world_size:
    #    print_separator('test parallel embedding')
    #    test_parallel_embedding(tensor_model_parallel_size)
    #    tensor_model_parallel_size *= 2

    print_separator('test column-parallel linear')
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_column_parallel_linear(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    print_separator('test row-parallel linear')
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_row_parallel_linear(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    # Deleted
    #print_separator('test parallel self-attention')
    #tensor_model_parallel_size = 1
    #while tensor_model_parallel_size <= world_size:
    #    test_parallel_self_attention(tensor_model_parallel_size)
    #    tensor_model_parallel_size *= 2

    # Deleted because ParallelTransformerLayer no longer exists
    # print_separator('test parallel transformer')
    # tensor_model_parallel_size = 1
    # while tensor_model_parallel_size <= world_size:
    #     test_parallel_transformer_layer(tensor_model_parallel_size)
    #     tensor_model_parallel_size *= 2
import torch
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel import mappings
from apex.transformer.tensor_parallel.tests import global_vars
global_vars.set_global_variables()
def test__reduce(args, tensor_model_parallel_size):
    """Check ``mappings._reduce`` on a constant tensor.

    A ``10 x 10 x 10 x 10`` tensor filled with 50 is reduced.  The result
    must equal a tensor filled with 50 times the tensor-model-parallel world
    size.

    Args:
        args: unused by this test.
        tensor_model_parallel_size: value handed to
            ``parallel_state.initialize_model_parallel``.
    """
    print("Testing reduction size =", tensor_model_parallel_size)
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tp_world_size = parallel_state.get_tensor_model_parallel_world_size()

    shape = (10, 10, 10, 10)
    reduced = mappings._reduce(torch.full(shape, (50)))
    expected = torch.full(shape, 50 * tp_world_size)
    assert torch.equal(reduced, expected)

    parallel_state.destroy_model_parallel()
    print("Passed!")
def test__split(args, tensor_model_parallel_size):
    """Check that ``mappings._split`` returns this rank's column block.

    One random ``(10, 1)`` column is drawn per tensor-parallel rank and the
    columns are concatenated along dim 1.  The split result must equal the
    column whose index is this process's tensor-model-parallel rank.

    Args:
        args: unused by this test.
        tensor_model_parallel_size: value handed to
            ``parallel_state.initialize_model_parallel``.
    """
    print("Testing splitting size =", tensor_model_parallel_size)
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tp_world_size = parallel_state.get_tensor_model_parallel_world_size()

    columns = [torch.randn(10, 1) for _ in range(tp_world_size)]
    joined = torch.cat(tuple(columns), 1)
    out = mappings._split(joined)
    assert torch.equal(out, columns[parallel_state.get_tensor_model_parallel_rank()])

    parallel_state.destroy_model_parallel()
    print("Passed!")
def test__gather(args, tensor_model_parallel_size):
    """Check that ``mappings._gather`` collects each rank's value in rank order.

    Every process contributes a one-element tensor holding its
    tensor-model-parallel rank.  The gathered result must equal
    ``[0, 1, ..., world_size - 1]``.

    Args:
        args: unused by this test.
        tensor_model_parallel_size: value handed to
            ``parallel_state.initialize_model_parallel``.
    """
    print("Testing gathering size =", tensor_model_parallel_size)
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tp_world_size = parallel_state.get_tensor_model_parallel_world_size()

    own_rank = torch.tensor([parallel_state.get_tensor_model_parallel_rank()])
    gathered = mappings._gather(own_rank)
    expected = torch.tensor(list(range(tp_world_size)))
    assert torch.equal(gathered, expected)

    parallel_state.destroy_model_parallel()
    print("Passed!")
if __name__ == "__main__":
    # Driver for the mapping tests: reduce, split and gather are each run at
    # tensor-model-parallel sizes 1, 2, 4, ... up to the world size.
    initialize_distributed()
    world_size = torch.distributed.get_world_size()
    args = global_vars.get_args()

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test__reduce(args, tensor_model_parallel_size)
        test__split(args, tensor_model_parallel_size)
        test__gather(args, tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    # Printed once, after every size has been exercised.
    print(">> passed the test :-)")
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.tensor_parallel.tests import global_vars
from apex.transformer.tensor_parallel.tests.commons import print_separator
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
def test_set_cuda_rng_state(tensor_model_parallel_size):
    """Check that ``_set_cuda_rng_state`` restores the CUDA RNG stream.

    The CUDA RNG state is captured, five random tensors are drawn, and the
    state is then restored (twice, with five draws after each restore).  The
    last tensor after the final restore must equal the last tensor of the
    first pass, and the captured state tensor must not have been modified.

    Args:
        tensor_model_parallel_size: value handed to
            ``parallel_state.initialize_model_parallel``.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    if torch.distributed.get_rank() == 0:
        print('> testing set_rng_state with size {} ...'.
              format(tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    size = 123
    seed = 1234
    num_draws = 5

    torch.cuda.manual_seed(seed)
    tensor = torch.cuda.FloatTensor(size)

    def _draw_repeatedly():
        # Each draw overwrites ``tensor`` in place and advances the CUDA RNG.
        for _ in range(num_draws):
            torch.randn(size, out=tensor)

    # Get the state
    rng_state = torch.cuda.get_rng_state()
    rng_state_copy = rng_state.clone()

    # Do some stuff.
    _draw_repeatedly()
    result_1 = tensor.clone()

    # The captured snapshot is unchanged; the live state has moved on.
    assert rng_state.sub(rng_state_copy).max() == 0
    assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0

    # State should be different.
    new_rng_state = torch.cuda.get_rng_state()
    max_diff = new_rng_state.sub(rng_state).max()
    print(' max diff in rng state (should be non-zero) on global rank {}: {}'.
          format(torch.distributed.get_rank(), max_diff))
    assert max_diff > 0

    # Reset the rng state and do the same stuff (two restore/draw rounds).
    tensor_parallel.random._set_cuda_rng_state(rng_state)
    _draw_repeatedly()
    tensor_parallel.random._set_cuda_rng_state(rng_state)
    _draw_repeatedly()
    result_2 = tensor.clone()

    # Results should be the same
    error = result_2.sub(result_1).abs().max()
    print(' max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Input state should have remained intact.
    error = rng_state.sub(rng_state_copy).max()
    print(' max error in rng state (should be zero) on global rank {}: {}'.
          format(torch.distributed.get_rank(), error))
    assert error == 0

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
def test_cuda_rng_tracker(tensor_model_parallel_size):
    """Check that a forked tracker state does not disturb the default CUDA stream.

    Two targets are drawn under ``seed_1`` and two under ``seed_2``.  The
    default generator is then re-seeded with ``seed_1`` and ``seed_2`` is
    registered in the tracker under the name ``'test'``.  Draws are
    interleaved between the default stream and ``fork('test')``.  Each
    interleaved draw must reproduce its target, and the two streams must
    differ from each other.

    Args:
        tensor_model_parallel_size: value handed to
            ``parallel_state.initialize_model_parallel``.

    Raises:
        AssertionError: if the streams coincide or a draw misses its target.
    """
    if torch.distributed.get_rank() == 0:
        print('> testing cuda rng tracker with size {} ...'.
              format(tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed_1 = 1234
    seed_2 = 4321
    size = [12, 21]
    tensor = torch.cuda.FloatTensor(size)

    def _draw():
        # Overwrite ``tensor`` with fresh samples and hand back a snapshot.
        torch.randn(size, out=tensor)
        return tensor.clone()

    # Set to seed_1 and generate two tensors.
    torch.cuda.manual_seed(seed_1)
    target_11 = _draw()
    target_12 = _draw()

    # Set to seed_2 and generate two tensors.
    torch.cuda.manual_seed(seed_2)
    target_21 = _draw()
    target_22 = _draw()

    # Now if we interleave seed_1 and seed_2,
    # we should still get the same tensors
    torch.cuda.manual_seed(seed_1)
    tensor_parallel.random.get_cuda_rng_tracker().add('test', seed_2)

    result_11 = _draw()

    with tensor_parallel.random.get_cuda_rng_tracker().fork('test'):
        result_21 = _draw()

    result_12 = _draw()

    with tensor_parallel.random.get_cuda_rng_tracker().fork('test'):
        result_22 = _draw()

    # The two streams must not coincide.
    diff = result_11.sub(result_21).abs().max()
    diff = min(diff, result_12.sub(result_22).abs().max())
    print(' max diff in generated tensors (should be non-zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
    assert diff > 1.0e-6

    # Every interleaved draw must reproduce its target.
    error = max(result_11.sub(target_11).abs().max(),
                result_12.sub(target_12).abs().max(),
                result_21.sub(target_21).abs().max(),
                result_22.sub(target_22).abs().max())
    print(' max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset the tracker
    tensor_parallel.random.get_cuda_rng_tracker().reset()

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):
    """Check the seeds installed by ``model_parallel_cuda_manual_seed``.

    After seeding with 12345, the default CUDA seed must be 12345.  Inside
    the tracker's default ``fork()`` it must be
    ``12345 + 2718 + tensor-model-parallel rank``.

    Args:
        tensor_model_parallel_size: value handed to
            ``parallel_state.initialize_model_parallel``.

    Raises:
        AssertionError: if either seed differs from the expected value.
    """
    if torch.distributed.get_rank() == 0:
        print(
            '> testing model parallel cuda manual seed with size {} ...'.format(
                tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    tensor_parallel.random.model_parallel_cuda_manual_seed(12345)
    assert torch.cuda.initial_seed() == 12345

    with tensor_parallel.random.get_cuda_rng_tracker().fork():
        # Inside the fork the seed is offset by 2718 plus this rank.
        forked_seed = torch.cuda.initial_seed()
        expected_seed = 12345 + 2718 + parallel_state.get_tensor_model_parallel_rank()
        assert (
            forked_seed ==
            expected_seed
        )

    # Reset the tracker
    tensor_parallel.random.get_cuda_rng_tracker().reset()

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
if __name__ == '__main__':
    # Driver for the RNG tests: each test is run at tensor-model-parallel
    # sizes 1, 2, 4, ... up to the world size, with a separator per run.
    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test set rng state')
        test_set_cuda_rng_state(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test cuda rng tracker')
        test_cuda_rng_tracker(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test model parallel cuda manual seed')
        test_model_parallel_cuda_manual_seed(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2
import torch
from apex.transformer.tensor_parallel import utils
def test_divide():
    """``utils.divide(8, 4)`` must return 2."""
    quotient = utils.divide(8, 4)
    assert quotient == 2
def test_split_tensor_along_last_dim():
    """Splitting a ``100 x 100 x 100`` tensor into 10 must give 10 pieces of last-dim size 10."""
    source = torch.randn((100, 100, 100))
    pieces = utils.split_tensor_along_last_dim(source, 10)
    # Collect only the last-dimension size of every piece.
    last_dim_shapes = torch.tensor([int(piece.size()[-1]) for piece in pieces])
    expected_shapes = torch.full((10,), 10)
    assert torch.equal(last_dim_shapes, expected_shapes)
if __name__ == "__main__":
    # Run both utility tests, then print the success message.
    test_divide()
    test_split_tensor_along_last_dim()
    print(">> passed the test :-)")
"""Test for fused softmax functions.
Ref: https://github.com/NVIDIA/Megatron-LM/blob/40becfc96c4144985458ac0e0fae45dbb111fbd2/megatron/fused_kernels/tests/test_fused_kernels.py
""" # NOQA
import itertools
import unittest
import torch
from apex.transformer import AttnMaskType
from apex.transformer.functional import FusedScaleMaskSoftmax
def attention_mask_func(attention_scores, attention_mask):
    """Fill masked positions of ``attention_scores`` with ``-10000.0`` in place.

    Args:
        attention_scores: tensor that is modified in place.
        attention_mask: boolean mask; positions where it is True are overwritten.

    Returns:
        The same ``attention_scores`` tensor object, after modification.
    """
    # masked_fill_ mutates its receiver and returns that same tensor.
    return attention_scores.masked_fill_(attention_mask, -10000.0)
# Reduced-precision dtypes iterated by the autocast tests: bfloat16 is included
# only when torch.cuda.is_bf16_supported() reports support on this machine.
autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
class TestFusedScaleMaskSoftmax(unittest.TestCase):
    """Compare ``FusedScaleMaskSoftmax`` with fusion enabled against fusion disabled.

    Every test builds two modules that differ only in
    ``scaled_masked_softmax_fusion`` and checks that they produce close
    outputs on the same inputs.
    """

    def _setup_fused_softmax(self, input_in_fp16, input_in_bf16, scale=None, softmax_in_fp32=False, attn_mask_type=AttnMaskType.padding):
        """Build the fused and the unfused softmax module with shared settings.

        Args:
            input_in_fp16: forwarded to ``FusedScaleMaskSoftmax``.
            input_in_bf16: forwarded to ``FusedScaleMaskSoftmax``.
            scale: forwarded to ``FusedScaleMaskSoftmax``.
            softmax_in_fp32: forwarded to ``FusedScaleMaskSoftmax``.
            attn_mask_type: forwarded to ``FusedScaleMaskSoftmax``.

        Returns:
            ``(fused_fn, torch_fn)``: the module built with
            ``scaled_masked_softmax_fusion=True`` and the one built with
            ``scaled_masked_softmax_fusion=False``.
        """
        shared_kwargs = dict(
            input_in_fp16=input_in_fp16,
            input_in_bf16=input_in_bf16,
            mask_func=attention_mask_func,
            scale=scale,
            softmax_in_fp32=softmax_in_fp32,
            attn_mask_type=attn_mask_type,
        )
        fused_fn = FusedScaleMaskSoftmax(
            scaled_masked_softmax_fusion=True, **shared_kwargs)
        torch_fn = FusedScaleMaskSoftmax(
            scaled_masked_softmax_fusion=False, **shared_kwargs)
        return fused_fn, torch_fn

    def test_fused_scale_mask_softmax(self):
        """
        attention_scores.shape = [4, 12, 24, 24]
        mask.shape = [4, 1, 24, 24]
        """
        # autocast_dtypes contains bfloat16 only when the device supports it,
        # so the bfloat16 combinations are exercised only where they can run.
        for (dtype, scale, softmax_in_fp32) in itertools.product(
            autocast_dtypes,
            (None, 2.0),
            (False, True),
        ):
            with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"):
                input_in_fp16 = dtype == torch.half
                input_in_bf16 = dtype == torch.bfloat16
                if not (scale is None or softmax_in_fp32):
                    # A scale without fp32 softmax is expected to be rejected
                    # at construction time.
                    with self.assertRaises(RuntimeError):
                        self._setup_fused_softmax(input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.padding)
                    # Move on to the next combination.  Returning here would
                    # silently skip every remaining combination.
                    continue
                fused_fn, torch_fn = self._setup_fused_softmax(input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.padding)

                attention_scores = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype)
                mask = torch.randint(0, 2, (4, 1, 24, 24), device="cuda").bool()
                reference = fused_fn(attention_scores, mask)
                actual = torch_fn(attention_scores, mask)
                torch.testing.assert_allclose(actual, reference)

    def test_autocast_fused_scale_mask_softmax(self):
        """Under autocast the fused output must have the autocast dtype and match the unfused one."""
        for dtype in autocast_dtypes:
            with self.subTest(f"{dtype}"):
                input_in_fp16 = dtype == torch.half
                input_in_bf16 = dtype == torch.bfloat16
                fused_fn, torch_fn = self._setup_fused_softmax(
                    input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding)

                attention_scores = torch.randn((4, 12, 24, 24)).cuda()
                mask = torch.randint(0, 2, (4, 1, 24, 24)).bool().cuda()

                with torch.cuda.amp.autocast(dtype=dtype):
                    actual = fused_fn(attention_scores, mask)
                    self.assertEqual(actual.dtype, dtype)
                # The unfused module is given inputs cast explicitly, outside autocast.
                with torch.no_grad():
                    expected = torch_fn(attention_scores.to(dtype), mask)
                torch.testing.assert_allclose(actual, expected)

    def test_fused_upper_triangle_mask_softmax(self):
        """
        attn_weights.shape: [4, 12, 24, 24]
        total_mask.shape: [4, 1, 24, 24]
        total_mask[0, 0], a 24x24 matrix is like a lower triangular matrix, but
        upper elements are True and lower elements and diagonal are False.
        """
        # autocast_dtypes contains bfloat16 only when the device supports it,
        # so the bfloat16 combinations are exercised only where they can run.
        for (dtype, scale, softmax_in_fp32) in itertools.product(
            autocast_dtypes,
            (None, 2.0),
            (False, True),
        ):
            with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"):
                input_in_fp16 = dtype == torch.half
                input_in_bf16 = dtype == torch.bfloat16
                if not (scale is None or softmax_in_fp32):
                    # A scale without fp32 softmax is expected to be rejected
                    # at construction time.
                    with self.assertRaises(RuntimeError):
                        self._setup_fused_softmax(
                            input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.causal)
                    # Move on to the next combination.  Returning here would
                    # silently skip every remaining combination.
                    continue
                fused_fn, torch_fn = self._setup_fused_softmax(
                    input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.causal)

                attn_weights = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype)
                total_mask = (~(
                    torch.tril(torch.randn((24, 24), device="cuda")).bool()
                ).unsqueeze(0).unsqueeze(0))
                total_mask = total_mask.repeat((4, 1, 1, 1))
                reference = fused_fn(attn_weights, total_mask)
                actual = torch_fn(attn_weights, total_mask)
                torch.testing.assert_allclose(actual, reference)

    def test_autocast_fused_upper_triangle_mask_softmax(self):
        """Causal-mask variant of the autocast comparison."""
        for dtype in autocast_dtypes:
            with self.subTest(f"{dtype}"):
                input_in_fp16 = dtype == torch.half
                input_in_bf16 = dtype == torch.bfloat16
                fused_fn, torch_fn = self._setup_fused_softmax(
                    input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.causal)

                attn_weights = torch.randn((4, 12, 24, 24)).cuda()
                # The mask keeps shape [1, 1, 24, 24] here (no repeat).
                total_mask = (~(
                    torch.tril(torch.randn((24, 24), device="cuda")).bool()
                ).unsqueeze(0).unsqueeze(0))

                with torch.cuda.amp.autocast(dtype=dtype):
                    actual = fused_fn(attn_weights, total_mask)
                    self.assertEqual(actual.dtype, dtype)
                # The unfused module is given inputs cast explicitly, outside autocast.
                with torch.no_grad():
                    expected = torch_fn(attn_weights.to(dtype), total_mask)
                torch.testing.assert_allclose(actual, expected)
import os
import subprocess
import sys
import unittest
def run_mpu_tests():
    """Run every ``run_*`` script located next to this file and report failures.

    Each script is started through the shell with ``NVIDIA_TF32_OVERRIDE=0``
    and a fixed set of command-line arguments.  A script passes only if the
    command exits successfully and its captured standard output contains
    ``'>> passed the test :-)'``.

    Raises:
        RuntimeError: if at least one script failed.  The message is
            ``"<failed> out of <total> tests failed"``, and a per-file log is
            printed before raising.
    """
    python_executable_path = sys.executable
    # abspath keeps dirname() from returning '' when __file__ is a bare name.
    directory = os.path.dirname(os.path.abspath(__file__))
    # Sorted so the execution order does not depend on os.listdir().
    files = sorted(
        os.path.join(directory, f) for f in os.listdir(directory)
        if f.startswith("run_") and os.path.isfile(os.path.join(directory, f))
    )
    print("#######################################################")
    print(f"# Python executable path: {python_executable_path}")
    print(f"# {len(files)} tests: {files}")
    print("#######################################################")

    # sys.stdout.encoding can be None (e.g. when stdout has been replaced).
    encoding = sys.stdout.encoding or "utf-8"

    errors = []
    for i, test_file in enumerate(files, 1):
        test_run_cmd = f"NVIDIA_TF32_OVERRIDE=0 {python_executable_path} {test_file} --micro-batch-size 2 --num-layers 1 --hidden-size 256 --num-attention-heads 8 --max-position-embeddings 32 --encoder-seq-length 32 --use-cpu-initialization"  # NOQA
        print(f"### {i} / {len(files)}: cmd: {test_run_cmd}")
        try:
            output = subprocess.check_output(
                test_run_cmd, shell=True
            ).decode(encoding).strip()
        except subprocess.CalledProcessError as e:
            # Keep whatever the script printed before failing; str(e) alone
            # only carries the command and its exit status.
            captured = (e.output or b"").decode(encoding, errors="replace").strip()
            errors.append((test_file, f"{e}\n{captured}" if captured else str(e)))
        except Exception as e:
            errors.append((test_file, str(e)))
        else:
            if '>> passed the test :-)' not in output:
                # list.append takes exactly one argument, so the
                # (file, log) pair must be passed as a single tuple.
                errors.append((test_file, output))

    if not errors:
        print("### PASSED")
        return

    print("### FAILED")
    short_msg = f"{len(errors)} out of {len(files)} tests failed"
    print(short_msg)
    for (filename, log) in errors:
        print(f"File: {filename}\nLog: {log}")
    raise RuntimeError(short_msg)
class TestMPU(unittest.TestCase):
    """``unittest`` wrapper so the ``run_*`` scripts are picked up by test discovery."""

    def test_mpu(self):
        """Delegate to ``run_mpu_tests``; any exception it raises fails this test."""
        run_mpu_tests()
if __name__ == '__main__':
    # Executed as a script: hand control to the unittest command-line runner.
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment