"vscode:/vscode.git/clone" did not exist on "bfdd1eaa446bd58ec35dbb54e247abed11c70084"
Commit 30f5009a authored by Tom Birch, committed by Mandeep Singh Baines

[feat] Model parallel (#3)

parent 8634280c
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Tuple

from .. import Tensor

def detach_variable(inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: ...
# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random

import numpy
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed


class IdentityLayer(torch.nn.Module):
    def __init__(self, size, scale=1.0):
        super(IdentityLayer, self).__init__()
        self.weight = torch.nn.Parameter(scale * torch.randn(size))

    def forward(self):
        return self.weight


def set_random_seed(seed):
    """Set random seed for reproducibility."""
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    model_parallel_cuda_manual_seed(seed)


def dist_init(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def get_world_sizes():
    limit = torch.cuda.device_count()
    return [x for x in [1, 2, 4, 8] if x <= limit]


def spawn_for_all_world_sizes(test_func, world_sizes=get_world_sizes()):
    for world_size in world_sizes:
        mp.spawn(test_func, args=(world_size,), nprocs=world_size, join=True)
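

# Editor's note (illustrative sketch, not part of this commit): how the helpers above
# are typically combined from a plain pytest test. `run_test_example` and
# `test_example` are made-up names used only for illustration.
def run_test_example(rank, world_size):
    # Each spawned process joins the NCCL group, pins its GPU, and seeds all RNGs.
    dist_init(rank, world_size)
    set_random_seed(1234)
    assert dist.get_world_size() == world_size
    assert dist.get_rank() == rank


def test_example():
    # Runs the worker once per supported world size (1, 2, 4, 8, capped by visible GPUs).
    spawn_for_all_world_sizes(run_test_example)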
# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F

from fairscale.nn.model_parallel import initialize as mpu
from fairscale.nn.model_parallel.cross_entropy import vocab_parallel_cross_entropy
from fairscale.nn.model_parallel.mappings import scatter_to_model_parallel_region
from tests.nn.model_parallel.commons import IdentityLayer, dist_init, set_random_seed, spawn_for_all_world_sizes


def torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
    set_random_seed(seed)
    identity = IdentityLayer((batch_size, seq_length, vocab_size), scale=logits_scale).cuda()
    logits = identity()
    target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size)
    loss = (
        F.cross_entropy(logits.view(-1, logits.size()[-1]), target.view(-1), reduction="none")
        .view_as(target)
        .mean()
    )
    loss.backward()
    return loss, identity.weight.grad


def mpu_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
    set_random_seed(seed)
    identity = IdentityLayer((batch_size, seq_length, vocab_size), scale=logits_scale).cuda()
    logits = identity()
    logits_parallel = scatter_to_model_parallel_region(logits)
    target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size)
    loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
    loss.backward()
    return loss, identity.weight.grad


def run_test_cross_entropy(rank, model_parallel_size):
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing cross entropy with model parallel size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    batch_size = 13
    seq_length = 17
    vocab_size_per_partition = 11
    logits_scale = 1000.0
    vocab_size = vocab_size_per_partition * model_parallel_size
    seed = 1234

    loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed)
    loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed)

    error = loss_torch.sub_(loss_mpu).abs().max()
    print(" max error in loss on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = grad_torch.sub_(grad_mpu).abs().max()
    print(" max error in grad on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")


def test_cross_entropy():
    spawn_for_all_world_sizes(run_test_cross_entropy)
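

# Editor's note (illustrative sketch, not part of this commit): the identity the test
# above relies on, shown on a single process. Splitting the logits along the vocab
# dimension and combining per-shard max / sum-exp statistics (the role of the
# all-reduces across model-parallel ranks) reproduces F.cross_entropy exactly.
# `_sharded_cross_entropy_sketch` is a made-up name.
def _sharded_cross_entropy_sketch():
    torch.manual_seed(0)
    logits = torch.randn(5, 8)  # (tokens, vocab)
    target = torch.randint(0, 8, (5,))
    shards = torch.chunk(logits, 2, dim=-1)  # two "model parallel" vocab shards

    # Per-shard max, then a global max (stands in for an all-reduce MAX).
    global_max = torch.stack([s.max(dim=-1).values for s in shards]).max(dim=0).values
    # Per-shard sum of exp, then a global sum (stands in for an all-reduce SUM).
    sum_exp = sum((s - global_max.unsqueeze(-1)).exp().sum(dim=-1) for s in shards)
    # Logit of the target class, owned by exactly one shard.
    target_logit = logits.gather(1, target.unsqueeze(-1)).squeeze(-1)

    loss_sharded = global_max + sum_exp.log() - target_logit
    loss_ref = F.cross_entropy(logits, target, reduction="none")
    assert torch.allclose(loss_sharded, loss_ref, atol=1e-6)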
# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from fairscale.nn.model_parallel import initialize as mpu
from tests.nn.model_parallel.commons import dist_init, spawn_for_all_world_sizes


def run_test_initialize_model_parallel(rank, model_parallel_size):
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing initialize_model_parallel with size {} ...".format(model_parallel_size))
    model_parallel_size_ = min(model_parallel_size, torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size_)
    assert mpu.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = model_parallel_size_
    rank = torch.distributed.get_rank() % model_parallel_size_
    assert world_size == mpu.get_model_parallel_world_size()
    assert rank == mpu.get_model_parallel_rank()
    check(mpu.get_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = torch.distributed.get_world_size() // model_parallel_size_
    rank = torch.distributed.get_rank() // model_parallel_size
    assert world_size == mpu.get_data_parallel_world_size()
    assert rank == mpu.get_data_parallel_rank()
    check(mpu.get_data_parallel_group(), world_size, rank)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")


def run_test_get_model_parallel_src_rank(rank, model_parallel_size_):
    dist_init(rank, model_parallel_size_)

    if torch.distributed.get_rank() == 0:
        print("> testing get_model_parallel_src_rank with size {} ...".format(model_parallel_size_))
    model_parallel_size = min(model_parallel_size_, torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size)
    assert mpu.model_parallel_is_initialized()

    # Checks
    src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank()
    assert mpu.get_model_parallel_src_rank() == src_rank

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")


def test_initialize_model_parallel():
    spawn_for_all_world_sizes(run_test_initialize_model_parallel)


def test_get_model_parallel_src_rank():
    spawn_for_all_world_sizes(run_test_get_model_parallel_src_rank)


def test_adjacency(monkeypatch):
    new_groups = []
    data_parallel_size = 32
    pipeline_length = 8
    model_parallel_size = 4

    class MockDistributed:
        def get_rank(self):
            return 0

        def is_initialized(self):
            return True

        def get_world_size(self):
            return data_parallel_size * pipeline_length * model_parallel_size

        def new_group(self, args):
            new_groups.append(args.copy())
            return ()

    monkeypatch.setattr(torch, "distributed", MockDistributed())

    mpu.initialize_model_parallel(model_parallel_size, pipeline_length)

    from collections import defaultdict

    buckets = defaultdict(list)
    for group in new_groups:
        buckets[len(group)].append(group)

    assert sorted(list(buckets.keys())) == [model_parallel_size, data_parallel_size]
    assert len(buckets[model_parallel_size]) == pipeline_length * data_parallel_size
    assert len(buckets[data_parallel_size]) == model_parallel_size * pipeline_length

    # Check that model_parallel groups are contiguous
    for group in buckets[model_parallel_size]:
        assert sorted(group) == group
        assert list(range(group[0], group[-1] + 1)) == group
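

# Editor's note (illustrative sketch, not part of this commit): the rank layout that
# run_test_initialize_model_parallel and test_adjacency assert, written out for a
# hypothetical world of 8 ranks with model_parallel_size = 2. Model-parallel groups
# are contiguous blocks of ranks, data-parallel groups stride across those blocks,
# and the model-parallel source rank is the first rank of its block.
def _rank_layout_sketch():
    world_size, model_parallel_size = 8, 2
    model_parallel_groups = [
        list(range(i, i + model_parallel_size)) for i in range(0, world_size, model_parallel_size)
    ]
    data_parallel_groups = [list(range(i, world_size, model_parallel_size)) for i in range(model_parallel_size)]
    assert model_parallel_groups == [[0, 1], [2, 3], [4, 5], [6, 7]]
    assert data_parallel_groups == [[0, 2, 4, 6], [1, 3, 5, 7]]
    # src_rank for a given global rank: subtract its model-parallel rank.
    assert all(rank - rank % model_parallel_size == group[0] for group in model_parallel_groups for rank in group)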
# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn
import torch.nn.init as init
from torch.nn.parameter import Parameter

from fairscale.nn.model_parallel import initialize as mpu
from fairscale.nn.model_parallel import layers
from fairscale.nn.pipe import Pipe
from tests.nn.model_parallel.commons import dist_init, get_world_sizes, set_random_seed, spawn_for_all_world_sizes


def run_test_parallel_embedding(rank, model_parallel_size):
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing parallel embedding with model parallel size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    batch_size = 17
    seq_length = 23
    vocab_size = 48
    hidden_size = 16
    seed = 1236

    set_random_seed(123)
    input_data = torch.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
    loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

    set_random_seed(seed)
    embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()

    output = embedding_original(input_data)
    loss_original = torch.mul(output, loss_weight).sum()
    loss_original.backward()

    set_random_seed(seed)
    embedding_parallel = layers.ParallelEmbedding(vocab_size, hidden_size, init_method=init.normal_).cuda()

    output = embedding_parallel(input_data)
    loss_parallel = torch.mul(output, loss_weight).sum()
    loss_parallel.backward()

    set_random_seed(seed)
    embedding_vocab_parallel = layers.VocabParallelEmbedding(vocab_size, hidden_size, init_method=init.normal_).cuda()

    output = embedding_vocab_parallel(input_data)
    loss_vocab_parallel = torch.mul(output, loss_weight).sum()
    loss_vocab_parallel.backward()

    torch.distributed.barrier()
    error = loss_parallel.sub(loss_original).abs()
    print(" error in loss (parallel) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-12, "error: {}".format(error)

    torch.distributed.barrier()
    error = loss_vocab_parallel.sub(loss_original).abs()
    print(" error in loss (vocab parallel) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-12, "error: {}".format(error)

    weight_grad_orig = torch.split(embedding_original.weight.grad, hidden_size // model_parallel_size, 1)[
        mpu.get_model_parallel_rank()
    ]
    error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
    print(" error in grad (parallel) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-12, "error: {}".format(error)

    weight_grad_orig = torch.split(embedding_original.weight.grad, vocab_size // model_parallel_size, 0)[
        mpu.get_model_parallel_rank()
    ]
    error = embedding_vocab_parallel.weight.grad.sub(weight_grad_orig).abs().max()
    print(" error in grad (vocab parallel) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-12, "error: {}".format(error)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")


def run_test_initialize_affine_weight(rank, model_parallel_size):
    dist_init(rank, model_parallel_size)
    mpu.initialize_model_parallel(model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing initialize_affine_weight with model parallel size: {}".format(model_parallel_size))

    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size

    # ---------------
    # Column parallel
    # ---------------
    weight = torch.empty(output_size_coeff, input_size)
    set_random_seed(seed)
    layers._initialize_affine_weight(weight, output_size, input_size, output_size_coeff, 0, torch.nn.init.normal_)

    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = mpu.get_model_parallel_rank()
    my_weight = torch.split(master_weight, output_size_coeff, dim=0)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(
        " column parallel max error (should be zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), error
        )
    )
    assert error < 1.0e-6

    # ------------
    # Row parallel
    # ------------
    weight = torch.empty(output_size, input_size_coeff)
    set_random_seed(seed)
    layers._initialize_affine_weight(weight, output_size, input_size, input_size_coeff, 1, torch.nn.init.normal_)

    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = mpu.get_model_parallel_rank()
    my_weight = torch.split(master_weight, input_size_coeff, dim=1)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(
        " row parallel max error (should be zero) on global rank {}: {}".format(torch.distributed.get_rank(), error)
    )
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(" >> passed the test :-)")


class IdentityLayer2D(torch.nn.Module):
    def __init__(self, m, n):
        super(IdentityLayer2D, self).__init__()
        self.weight = Parameter(torch.Tensor(m, n))
        torch.nn.init.xavier_normal_(self.weight)

    def forward(self):
        return self.weight


def run_test_column_parallel_linear(rank, model_parallel_size):
    dist_init(rank, model_parallel_size)
    mpu.initialize_model_parallel(model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing ColumnParallelLinear with model parallel size: {}".format(model_parallel_size))

    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = layers.ColumnParallelLinear(input_size, output_size, keep_master_weight_for_test=True).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()

    # Forward
    input_ = identity_layer()
    output = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()

    # Backward
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = mpu.get_model_parallel_rank()
    my_dLdA = torch.split(dLdA, output_size_coeff, dim=0)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(" error in dLdA on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    my_dLdb = torch.split(dLdb, output_size_coeff, dim=0)[rank].contiguous().clone()
    error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print(" error in dLdb on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(" error in dLdX on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(" >> passed the test :-)")


def run_test_row_parallel_linear(rank, model_parallel_size):
    dist_init(rank, model_parallel_size)
    mpu.initialize_model_parallel(model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing RowParallelLinear with model parallel size: {}".format(model_parallel_size))

    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = layers.RowParallelLinear(input_size, output_size, keep_master_weight_for_test=True).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()

    # Forward
    input_ = identity_layer()
    output = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()

    # Backward
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = mpu.get_model_parallel_rank()
    my_dLdA = torch.split(dLdA, input_size_coeff, dim=1)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(" error in dLdA on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print(" error in dLdb on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(" error in dLdX on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(" >> passed the test :-)")


def run_test_pipe(rank, model_parallel_size):
    pipe_world_size = 2

    dist_init(rank, model_parallel_size)
    mpu.initialize_model_parallel(model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print(
            "> testing Sequential + Pipe with model parallel size: {}, pipe: {}".format(
                model_parallel_size, pipe_world_size
            )
        )

    model_parallel_size = mpu.get_model_parallel_world_size()

    chunk_size = 8

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size
    batch_size = 7 * chunk_size

    identity = IdentityLayer2D(batch_size, input_size).cuda()

    pipeline_devices = mpu.get_pipeline_parallel_group()
    if pipe_world_size == 2 and len(pipeline_devices) == 1:
        pipeline_devices.append(pipeline_devices[0] + model_parallel_size)

    set_random_seed(seed)
    model = nn.Sequential(
        layers.ColumnParallelLinear(input_size, output_size, keep_master_weight_for_test=True, bias=False).cuda(),
        nn.ReLU(),
        layers.RowParallelLinear(output_size, input_size, keep_master_weight_for_test=True, bias=False).cuda(),
    )

    set_random_seed(seed)
    reference = nn.Sequential(
        nn.Linear(input_size, output_size, bias=False).cuda(),
        nn.ReLU(),
        nn.Linear(output_size, input_size, bias=False).cuda(),
    )
    reference[0].weight.data = model[0].master_weight.cuda()
    reference[-1].weight.data = model[-1].master_weight.cuda()

    loss_weight = torch.randn([batch_size, output_size]).cuda()

    output = model(identity())
    reference_output = reference(identity())

    error = reference_output.sub(output).max()
    torch.distributed.barrier()
    assert error < 1.0e-6

    if pipe_world_size == 2:
        pipe_model = Pipe(model, [2, 1], devices=pipeline_devices, chunks=chunk_size)
        torch.distributed.barrier()
        pipe_output = pipe_model(identity())
        error = reference_output.sub(pipe_output.cuda()).max()
        torch.distributed.barrier()
        assert error < 1.0e-6


torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def test_affine_weight():
    spawn_for_all_world_sizes(run_test_initialize_affine_weight)


def test_embedding():
    spawn_for_all_world_sizes(run_test_parallel_embedding)


def test_column_parallel():
    spawn_for_all_world_sizes(run_test_column_parallel_linear)


def test_row_parallel():
    spawn_for_all_world_sizes(run_test_row_parallel_linear)


def test_pipe():
    world_sizes = [x for x in get_world_sizes() if x <= torch.cuda.device_count() / 2]
    spawn_for_all_world_sizes(run_test_pipe, world_sizes)
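

# Editor's note (illustrative sketch, not part of this commit): the partitioning the
# two linear-layer tests above verify, shown on a single device. A column-parallel
# linear splits the weight along the output dimension and concatenates per-partition
# outputs; a row-parallel linear splits it along the input dimension and sums
# per-partition partial products. `_parallel_linear_sketch` is a made-up name.
def _parallel_linear_sketch():
    torch.manual_seed(0)
    batch, in_features, out_features, partitions = 4, 6, 8, 2
    x = torch.randn(batch, in_features)
    weight = torch.randn(out_features, in_features)  # the unpartitioned "master" weight
    reference = x @ weight.t()

    # Column parallel: each partition holds out_features // partitions rows of the weight.
    col_shards = torch.chunk(weight, partitions, dim=0)
    col_output = torch.cat([x @ shard.t() for shard in col_shards], dim=-1)
    assert torch.allclose(reference, col_output, atol=1e-6)

    # Row parallel: each partition holds in_features // partitions columns of the weight
    # and sees only the matching slice of the input; partial results are summed.
    row_shards = torch.chunk(weight, partitions, dim=1)
    x_shards = torch.chunk(x, partitions, dim=1)
    row_output = sum(xs @ shard.t() for xs, shard in zip(x_shards, row_shards))
    assert torch.allclose(reference, row_output, atol=1e-6)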
# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from fairscale.nn.model_parallel import initialize as mpu
from fairscale.nn.model_parallel import random
from fairscale.nn.model_parallel.random import get_cuda_rng_tracker, model_parallel_cuda_manual_seed
from tests.nn.model_parallel.commons import dist_init, spawn_for_all_world_sizes


def run_test_set_cuda_rng_state(rank, model_parallel_size):
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing set_rng_state with size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    size = 123
    seed = 1234
    torch.cuda.manual_seed(1234)
    tensor = torch.cuda.FloatTensor(size)

    # Get the state
    rng_state = torch.cuda.get_rng_state()
    rng_state_copy = rng_state.clone()

    # Do some stuff.
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_1 = tensor.clone()

    assert rng_state.sub(rng_state_copy).max() == 0
    assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0

    # State should be different.
    new_rng_state = torch.cuda.get_rng_state()
    max_diff = new_rng_state.sub(rng_state).max()
    print(
        " max diff in rng state (should be non-zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), max_diff
        )
    )
    assert max_diff > 0

    # Reset the rng state and do the same stuff.
    random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_2 = tensor.clone()

    # Results should be the same
    error = result_2.sub(result_1).abs().max()
    print(
        " max error in generated tensors (should be zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), error
        )
    )
    assert error < 1.0e-6

    # Input state should have remained intact.
    error = rng_state.sub(rng_state_copy).max()
    print(
        " max error in rng state (should be zero) on global rank {}: {}".format(torch.distributed.get_rank(), error)
    )
    assert error == 0

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")


def run_test_cuda_rng_tracker(rank, model_parallel_size):
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing cuda rng tracker with size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed_1 = 1234
    seed_2 = 4321
    size = [12, 21]
    tensor = torch.cuda.FloatTensor(size)

    # Set to seed_1 and generate two tensors.
    torch.cuda.manual_seed(seed_1)
    torch.randn(size, out=tensor)
    target_11 = tensor.clone()
    torch.randn(size, out=tensor)
    target_12 = tensor.clone()

    # Set to seed_2 and generate two tensors.
    torch.cuda.manual_seed(seed_2)
    torch.randn(size, out=tensor)
    target_21 = tensor.clone()
    torch.randn(size, out=tensor)
    target_22 = tensor.clone()

    # Now if we interleave seed_1 and seed_2,
    # we should still get the same tensors
    torch.cuda.manual_seed(seed_1)
    get_cuda_rng_tracker().add("test", seed_2)

    torch.randn(size, out=tensor)
    result_11 = tensor.clone()

    with get_cuda_rng_tracker().fork("test"):
        torch.randn(size, out=tensor)
        result_21 = tensor.clone()

    torch.randn(size, out=tensor)
    result_12 = tensor.clone()

    with get_cuda_rng_tracker().fork("test"):
        torch.randn(size, out=tensor)
        result_22 = tensor.clone()

    diff = result_11.sub(result_21).abs().max()
    diff = min(diff, result_12.sub(result_22).abs().max())
    print(
        " max diff in generated tensors (should be non-zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), diff
        )
    )
    assert diff > 1.0e-6

    error = max(result_11.sub(target_11).abs().max(), result_12.sub(target_12).abs().max())
    error = max(error, result_21.sub(target_21).abs().max())
    error = max(error, result_22.sub(target_22).abs().max())
    print(
        " max error in generated tensors (should be zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), error
        )
    )
    assert error < 1.0e-6

    # Reset the tracker
    get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")


def run_test_model_parallel_cuda_manual_seed(rank, model_parallel_size):
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing model parallel cuda manual seed with size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    model_parallel_cuda_manual_seed(12345)
    assert torch.cuda.initial_seed() == 12345
    with get_cuda_rng_tracker().fork():
        assert torch.cuda.initial_seed() == (12345 + 2718 + mpu.get_model_parallel_rank())

    # Reset the tracker
    get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")


def test_set_cuda_rng_state():
    spawn_for_all_world_sizes(run_test_set_cuda_rng_state)


def test_cuda_rng_tracker():
    spawn_for_all_world_sizes(run_test_cuda_rng_tracker)


def test_model_parallel_cuda_manual_seed():
    spawn_for_all_world_sizes(run_test_model_parallel_cuda_manual_seed)
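

# Editor's note (illustrative sketch, not part of this commit): the state-swap pattern
# the CUDA RNG tracker is built on, shown with the CPU generator so it runs anywhere.
# Forking into a separately seeded stream and restoring the saved state leaves the
# default stream exactly where it was. `_rng_fork_sketch` is a made-up name.
def _rng_fork_sketch():
    torch.manual_seed(1234)
    expected = torch.randn(3)  # what the default stream would produce next

    torch.manual_seed(1234)
    saved = torch.get_rng_state()
    torch.manual_seed(4321)  # "fork": switch to a separately seeded stream
    _ = torch.randn(3)
    torch.set_rng_state(saved)  # restore on exit; the default stream is untouched

    assert torch.equal(torch.randn(3), expected)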