Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
30f5009a
Commit
30f5009a
authored
Jul 22, 2020
by
Tom Birch
Committed by
Mandeep Singh Baines
Jul 31, 2020
Browse files
[feat] Model parallel (#3)
parent
8634280c
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
889 additions
and
0 deletions
+889
-0
stubs/torch/utils/checkpoint.pyi
stubs/torch/utils/checkpoint.pyi
+6
-0
tests/nn/model_parallel/commons.py
tests/nn/model_parallel/commons.py
+64
-0
tests/nn/model_parallel/test_cross_entropy.py
tests/nn/model_parallel/test_cross_entropy.py
+88
-0
tests/nn/model_parallel/test_initialize.py
tests/nn/model_parallel/test_initialize.py
+136
-0
tests/nn/model_parallel/test_layers.py
tests/nn/model_parallel/test_layers.py
+386
-0
tests/nn/model_parallel/test_random.py
tests/nn/model_parallel/test_random.py
+209
-0
No files found.
stubs/torch/utils/checkpoint.pyi
0 → 100644
View file @
30f5009a
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Tuple
from .. import Tensor
# Type stub: accepts a tuple of tensors and is declared to return a tuple of tensors.
def detach_variable(inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]: ...
tests/nn/model_parallel/commons.py
0 → 100644
View file @
30f5009a
# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
random
import
numpy
import
torch
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
from
fairscale.nn.model_parallel.random
import
model_parallel_cuda_manual_seed
class IdentityLayer(torch.nn.Module):
    """Module whose single parameter is returned unchanged by ``forward``.

    The parameter is created as ``scale * torch.randn(size)``.
    """

    def __init__(self, size, scale=1.0):
        super().__init__()
        initial_values = torch.randn(size)
        self.weight = torch.nn.Parameter(scale * initial_values)

    def forward(self):
        """Return the stored parameter itself (no input is taken)."""
        return self.weight
def set_random_seed(seed):
    """Set random seed for reproducibility.

    Seeds, in order: Python's ``random``, ``numpy.random``, ``torch`` and
    the model-parallel CUDA seeding helper.
    """
    seeders = (
        random.seed,
        numpy.random.seed,
        torch.manual_seed,
        model_parallel_cuda_manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def dist_init(rank, world_size):
    """Join the NCCL process group and bind this process to CUDA device ``rank``.

    The rendezvous address is fixed to ``localhost:29501`` through the
    ``MASTER_ADDR`` / ``MASTER_PORT`` environment variables.
    """
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "29501"})
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
def get_world_sizes():
    """Return the candidates among 1, 2, 4 and 8 that do not exceed the CUDA device count."""
    available = torch.cuda.device_count()
    candidates = (1, 2, 4, 8)
    return [size for size in candidates if size <= available]
def spawn_for_all_world_sizes(test_func, world_sizes=None):
    """Run ``test_func`` once per world size, each time in ``world_size`` spawned processes.

    Args:
        test_func: callable handed to ``mp.spawn``; it is invoked with the
            process index followed by the world size.
        world_sizes: iterable of world sizes to run. When ``None`` (the
            default) the sizes come from ``get_world_sizes()``.
    """
    # Resolve the default at call time: evaluating get_world_sizes() in the
    # signature would query CUDA on import and share one list across calls.
    if world_sizes is None:
        world_sizes = get_world_sizes()

    for world_size in world_sizes:
        mp.spawn(test_func, args=(world_size,), nprocs=world_size, join=True)
tests/nn/model_parallel/test_cross_entropy.py
0 → 100644
View file @
30f5009a
# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
import
torch.nn.functional
as
F
from
fairscale.nn.model_parallel
import
initialize
as
mpu
from
fairscale.nn.model_parallel.cross_entropy
import
vocab_parallel_cross_entropy
from
fairscale.nn.model_parallel.mappings
import
scatter_to_model_parallel_region
from
tests.nn.model_parallel.commons
import
IdentityLayer
,
dist_init
,
set_random_seed
,
spawn_for_all_world_sizes
def torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
    """Compute the mean cross entropy with ``F.cross_entropy`` and back-propagate it.

    Returns:
        The loss and the gradient accumulated on the identity layer's weight.
    """
    set_random_seed(seed)

    logits_shape = (batch_size, seq_length, vocab_size)
    identity = IdentityLayer(logits_shape, scale=logits_scale).cuda()
    logits = identity()
    target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size)

    # Flatten to (tokens, vocab) / (tokens,) for F.cross_entropy, then
    # restore the target's shape before averaging.
    flat_logits = logits.view(-1, logits.size()[-1])
    flat_target = target.view(-1)
    per_token_loss = F.cross_entropy(flat_logits, flat_target, reduction="none")
    loss = per_token_loss.view_as(target).mean()

    loss.backward()
    return loss, identity.weight.grad
def mpu_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
    """Compute the mean ``vocab_parallel_cross_entropy`` on scattered logits and back-propagate it.

    Returns:
        The loss and the gradient accumulated on the identity layer's weight.
    """
    set_random_seed(seed)

    logits_shape = (batch_size, seq_length, vocab_size)
    identity = IdentityLayer(logits_shape, scale=logits_scale).cuda()
    logits_parallel = scatter_to_model_parallel_region(identity())
    target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size)

    per_token_loss = vocab_parallel_cross_entropy(logits_parallel, target)
    loss = per_token_loss.mean()

    loss.backward()
    return loss, identity.weight.grad
def run_test_cross_entropy(rank, model_parallel_size):
    """Compare the vocab-parallel cross entropy against the plain torch implementation.

    Both the loss and the logits gradient must agree to within 1e-6.
    """
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing cross entropy with model parallel size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    batch_size = 13
    seq_length = 17
    vocab_size_per_partition = 11
    logits_scale = 1000.0
    vocab_size = vocab_size_per_partition * model_parallel_size
    seed = 1234

    problem = (batch_size, seq_length, vocab_size, logits_scale, seed)
    loss_torch, grad_torch = torch_cross_entropy(*problem)
    loss_mpu, grad_mpu = mpu_cross_entropy(*problem)

    comparisons = (
        (" max error in loss on global rank {}: {}", loss_torch, loss_mpu),
        (" max error in grad on global rank {}: {}", grad_torch, grad_mpu),
    )
    for message, reference, candidate in comparisons:
        # In-place subtraction, as in the reference tensors' only remaining use.
        error = reference.sub_(candidate).abs().max()
        print(message.format(torch.distributed.get_rank(), error))
        assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")
def test_cross_entropy():
    """Run ``run_test_cross_entropy`` through ``spawn_for_all_world_sizes``."""
    spawn_for_all_world_sizes(run_test_cross_entropy)
tests/nn/model_parallel/test_initialize.py
0 → 100644
View file @
30f5009a
# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
from
fairscale.nn.model_parallel
import
initialize
as
mpu
from
tests.nn.model_parallel.commons
import
dist_init
,
spawn_for_all_world_sizes
def run_test_initialize_model_parallel(rank, model_parallel_size):
    """Check group sizes and ranks reported after ``mpu.initialize_model_parallel``.

    Args:
        rank: index of this spawned process, passed to ``dist_init``.
        model_parallel_size: requested model-parallel size; it is clamped to
            the distributed world size before being used.
    """
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing initialize_model_parallel with size {} ...".format(model_parallel_size))
    model_parallel_size_ = min(model_parallel_size, torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size_)
    assert mpu.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = model_parallel_size_
    rank = torch.distributed.get_rank() % model_parallel_size_
    assert world_size == mpu.get_model_parallel_world_size()
    assert rank == mpu.get_model_parallel_rank()
    check(mpu.get_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = torch.distributed.get_world_size() // model_parallel_size_
    # Use the clamped size here as well, so the expected rank is derived from
    # the same value that was handed to initialize_model_parallel.
    rank = torch.distributed.get_rank() // model_parallel_size_
    assert world_size == mpu.get_data_parallel_world_size()
    assert rank == mpu.get_data_parallel_rank()
    check(mpu.get_data_parallel_group(), world_size, rank)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")
def run_test_get_model_parallel_src_rank(rank, model_parallel_size_):
    """Check that the model-parallel source rank equals global rank minus model-parallel rank."""
    dist_init(rank, model_parallel_size_)

    is_rank_zero = torch.distributed.get_rank() == 0
    if is_rank_zero:
        print("> testing get_model_parallel_src_rank with size {} ...".format(model_parallel_size_))

    # Never request more model-parallel ranks than there are processes.
    model_parallel_size = min(model_parallel_size_, torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size)
    assert mpu.model_parallel_is_initialized()

    # Checks
    expected_src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank()
    assert mpu.get_model_parallel_src_rank() == expected_src_rank

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")
def test_initialize_model_parallel():
    """Run ``run_test_initialize_model_parallel`` through ``spawn_for_all_world_sizes``."""
    spawn_for_all_world_sizes(run_test_initialize_model_parallel)
def test_get_model_parallel_src_rank():
    """Run ``run_test_get_model_parallel_src_rank`` through ``spawn_for_all_world_sizes``."""
    spawn_for_all_world_sizes(run_test_get_model_parallel_src_rank)
def test_adjacency(monkeypatch):
    """Check the rank groups requested by ``initialize_model_parallel`` against a mocked ``torch.distributed``.

    The mock records every rank list passed to ``new_group``; the recorded
    groups are then bucketed by length and checked for count and contiguity.
    """
    from collections import defaultdict

    recorded_groups = []

    data_parallel_size = 32
    pipeline_length = 8
    model_parallel_size = 4

    class MockDistributed:
        """Stand-in for ``torch.distributed`` that only records ``new_group`` calls."""

        def get_rank(self):
            return 0

        def is_initialized(self):
            return True

        def get_world_size(self):
            return data_parallel_size * pipeline_length * model_parallel_size

        def new_group(self, args):
            # Copy so later mutation of the caller's list cannot alter the record.
            recorded_groups.append(args.copy())
            return ()

    monkeypatch.setattr(torch, "distributed", MockDistributed())

    mpu.initialize_model_parallel(model_parallel_size, pipeline_length)

    groups_by_size = defaultdict(list)
    for ranks in recorded_groups:
        groups_by_size[len(ranks)].append(ranks)

    assert sorted(groups_by_size) == [model_parallel_size, data_parallel_size]

    assert len(groups_by_size[model_parallel_size]) == pipeline_length * data_parallel_size
    assert len(groups_by_size[data_parallel_size]) == model_parallel_size * pipeline_length

    # Check that model_parallel groups are contiguous
    for ranks in groups_by_size[model_parallel_size]:
        assert sorted(ranks) == ranks
        first, last = ranks[0], ranks[-1]
        assert list(range(first, last + 1)) == ranks
tests/nn/model_parallel/test_layers.py
0 → 100644
View file @
30f5009a
# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
from
torch
import
nn
import
torch.nn.init
as
init
from
torch.nn.parameter
import
Parameter
from
fairscale.nn.model_parallel
import
initialize
as
mpu
from
fairscale.nn.model_parallel
import
layers
from
fairscale.nn.pipe
import
Pipe
from
tests.nn.model_parallel.commons
import
dist_init
,
get_world_sizes
,
set_random_seed
,
spawn_for_all_world_sizes
def run_test_parallel_embedding(rank, model_parallel_size):
    """Compare ``ParallelEmbedding`` and ``VocabParallelEmbedding`` with ``torch.nn.Embedding``.

    Each embedding is built after re-seeding with the same seed, run on the
    same input, and reduced with the same loss weights; losses and the
    matching slice of the reference weight gradient must agree.
    """
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing parallel embedding with model parallel size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    batch_size = 17
    seq_length = 23
    vocab_size = 48
    hidden_size = 16
    seed = 1236

    set_random_seed(123)
    input_data = torch.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
    loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

    def forward_backward(embedding):
        # Weighted sum of the embedding output, back-propagated immediately.
        loss = torch.mul(embedding(input_data), loss_weight).sum()
        loss.backward()
        return loss

    def report(message, error):
        print(message.format(torch.distributed.get_rank(), error))
        assert error < 1.0e-12, "error: {}".format(error)

    set_random_seed(seed)
    embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()
    loss_original = forward_backward(embedding_original)

    set_random_seed(seed)
    embedding_parallel = layers.ParallelEmbedding(vocab_size, hidden_size, init_method=init.normal_).cuda()
    loss_parallel = forward_backward(embedding_parallel)

    set_random_seed(seed)
    embedding_vocab_parallel = layers.VocabParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_
    ).cuda()
    loss_vocab_parallel = forward_backward(embedding_vocab_parallel)

    torch.distributed.barrier()
    report(
        " error in loss (parallel) on global rank {}: {}",
        loss_parallel.sub(loss_original).abs(),
    )

    torch.distributed.barrier()
    report(
        " error in loss (vocab parallel) on global rank {}: {}",
        loss_vocab_parallel.sub(loss_original).abs(),
    )

    # ParallelEmbedding: compare with this rank's slice along dim 1.
    grad_slices = torch.split(embedding_original.weight.grad, hidden_size // model_parallel_size, 1)
    weight_grad_orig = grad_slices[mpu.get_model_parallel_rank()]
    report(
        " error in grad (parallel) on global rank {}: {}",
        embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max(),
    )

    # VocabParallelEmbedding: compare with this rank's slice along dim 0.
    grad_slices = torch.split(embedding_original.weight.grad, vocab_size // model_parallel_size, 0)
    weight_grad_orig = grad_slices[mpu.get_model_parallel_rank()]
    report(
        " error in grad (vocab parallel) on global rank {}: {}",
        embedding_vocab_parallel.weight.grad.sub(weight_grad_orig).abs().max(),
    )

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")
def run_test_initialize_affine_weight(rank, model_parallel_size):
    """Check ``layers._initialize_affine_weight`` against a slice of a full re-seeded weight.

    Runs once with partition dim 0 ("column parallel") and once with
    partition dim 1 ("row parallel").
    """
    dist_init(rank, model_parallel_size)

    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print("> testing initialize_affine_weight with model parallel size: {}".format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size

    # (local weight shape, per-partition size, partition dim, report message)
    cases = (
        (
            (output_size_coeff, input_size),
            output_size_coeff,
            0,
            " column parallel max error (should be zero) on global rank {}: {}",
        ),
        (
            (output_size, input_size_coeff),
            input_size_coeff,
            1,
            " row parallel max error (should be zero) on global rank {}: {}",
        ),
    )

    for local_shape, per_partition_size, partition_dim, message in cases:
        weight = torch.empty(*local_shape)
        set_random_seed(seed)
        layers._initialize_affine_weight(
            weight, output_size, input_size, per_partition_size, partition_dim, torch.nn.init.normal_
        )

        # Target.
        set_random_seed(seed)
        master_weight = torch.empty(output_size, input_size)
        torch.nn.init.normal_(master_weight)
        rank = mpu.get_model_parallel_rank()
        partitions = torch.split(master_weight, per_partition_size, dim=partition_dim)
        my_weight = partitions[rank].contiguous().clone()

        # Compare.
        error = weight.sub(my_weight).abs().max()
        torch.distributed.barrier()
        print(message.format(torch.distributed.get_rank(), error))
        assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(" >> passed the test :-)")
class IdentityLayer2D(torch.nn.Module):
    """Module holding an ``m`` x ``n`` Xavier-normal parameter that ``forward`` returns as-is."""

    def __init__(self, m, n):
        super().__init__()
        storage = torch.Tensor(m, n)
        self.weight = Parameter(storage)
        # Fill the freshly allocated parameter in place.
        torch.nn.init.xavier_normal_(self.weight)

    def forward(self):
        """Return the stored parameter itself (no input is taken)."""
        return self.weight
def run_test_column_parallel_linear(rank, model_parallel_size):
    """Check ``ColumnParallelLinear`` gradients against hand-computed matrix products.

    The weight and bias gradients are compared with this rank's slice
    (along dim 0) of the full gradients; the input gradient is compared whole.
    """
    dist_init(rank, model_parallel_size)

    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print("> testing ColumnParallelLinear with model parallel size: {}".format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size
    batch_size = 7

    def verify(message, expected, actual):
        error = expected.sub(actual).abs().max()
        torch.distributed.barrier()
        print(message.format(torch.distributed.get_rank(), error))
        assert error < 1.0e-6

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = layers.ColumnParallelLinear(input_size, output_size, keep_master_weight_for_test=True).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()

    # Forward and backward
    output = linear_layer(identity_layer())
    loss = torch.mul(output, loss_weight).sum()
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    ones_row = torch.ones(batch_size, 1).cuda().t()
    dLdb = torch.matmul(ones_row, dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = mpu.get_model_parallel_rank()

    my_dLdA = torch.split(dLdA, output_size_coeff, dim=0)[rank].contiguous().clone()
    verify(" error in dLdA on global rank {}: {}", my_dLdA, linear_layer.weight.grad)

    my_dLdb = torch.split(dLdb, output_size_coeff, dim=0)[rank].contiguous().clone()
    verify(" error in dLdb on global rank {}: {}", my_dLdb, linear_layer.bias.grad)

    verify(" error in dLdX on global rank {}: {}", dLdX, identity_layer.weight.grad)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(" >> passed the test :-)")
def run_test_row_parallel_linear(rank, model_parallel_size):
    """Check ``RowParallelLinear`` gradients against hand-computed matrix products.

    The weight gradient is compared with this rank's slice (along dim 1) of
    the full gradient; the bias and input gradients are compared whole.
    """
    dist_init(rank, model_parallel_size)

    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print("> testing RowParallelLinear with model parallel size: {}".format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size
    batch_size = 7

    def verify(message, expected, actual):
        error = expected.sub(actual).abs().max()
        torch.distributed.barrier()
        print(message.format(torch.distributed.get_rank(), error))
        assert error < 1.0e-6

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = layers.RowParallelLinear(input_size, output_size, keep_master_weight_for_test=True).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()

    # Forward and backward
    output = linear_layer(identity_layer())
    loss = torch.mul(output, loss_weight).sum()
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    ones_row = torch.ones(batch_size, 1).cuda().t()
    dLdb = torch.matmul(ones_row, dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = mpu.get_model_parallel_rank()

    my_dLdA = torch.split(dLdA, input_size_coeff, dim=1)[rank].contiguous().clone()
    verify(" error in dLdA on global rank {}: {}", my_dLdA, linear_layer.weight.grad)

    verify(" error in dLdb on global rank {}: {}", dLdb, linear_layer.bias.grad)

    verify(" error in dLdX on global rank {}: {}", dLdX, identity_layer.weight.grad)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(" >> passed the test :-)")
def run_test_pipe(rank, model_parallel_size):
    """Compare a model-parallel ``nn.Sequential`` (plain and wrapped in ``Pipe``) with ``nn.Linear`` layers.

    The reference network copies the parallel layers' master weights, so
    both networks should produce the same output for the same input.

    Args:
        rank: index of this spawned process, passed to ``dist_init``.
        model_parallel_size: requested model-parallel size.
    """
    pipe_world_size = 2

    dist_init(rank, model_parallel_size)
    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print(
            "> testing Sequential + Pipe with model parallel size: {}, pipe: {}".format(
                model_parallel_size, pipe_world_size
            )
        )
    model_parallel_size = mpu.get_model_parallel_world_size()
    chunk_size = 8

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size
    batch_size = 7 * chunk_size

    identity = IdentityLayer2D(batch_size, input_size).cuda()

    pipeline_devices = mpu.get_pipeline_parallel_group()
    # With a single-entry pipeline group, add a second device offset by the
    # model-parallel size so Pipe has two devices to place stages on.
    if pipe_world_size == 2 and len(pipeline_devices) == 1:
        pipeline_devices.append(pipeline_devices[0] + model_parallel_size)

    set_random_seed(seed)
    model = nn.Sequential(
        layers.ColumnParallelLinear(input_size, output_size, keep_master_weight_for_test=True, bias=False).cuda(),
        nn.ReLU(),
        layers.RowParallelLinear(output_size, input_size, keep_master_weight_for_test=True, bias=False).cuda(),
    )
    set_random_seed(seed)

    reference = nn.Sequential(
        nn.Linear(input_size, output_size, bias=False).cuda(),
        nn.ReLU(),
        nn.Linear(output_size, input_size, bias=False).cuda(),
    )

    # Give the reference exactly the weights the parallel layers were split from.
    reference[0].weight.data = model[0].master_weight.cuda()
    reference[-1].weight.data = model[-1].master_weight.cuda()

    output = model(identity())
    reference_output = reference(identity())

    # Take the absolute value before the max: a signed max would let an
    # arbitrarily large negative deviation pass the tolerance check.
    error = reference_output.sub(output).abs().max()
    torch.distributed.barrier()
    assert error < 1.0e-6

    if pipe_world_size == 2:
        pipe_model = Pipe(model, [2, 1], devices=pipeline_devices, chunks=chunk_size)
        torch.distributed.barrier()
        pipe_output = pipe_model(identity())

        error = reference_output.sub(pipe_output.cuda()).abs().max()
        torch.distributed.barrier()
        assert error < 1.0e-6
# Module-level cuDNN settings: request deterministic algorithms and turn
# off the benchmark-based algorithm selection.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def test_affine_weight():
    """Run ``run_test_initialize_affine_weight`` through ``spawn_for_all_world_sizes``."""
    spawn_for_all_world_sizes(run_test_initialize_affine_weight)
def test_embedding():
    """Run ``run_test_parallel_embedding`` through ``spawn_for_all_world_sizes``."""
    spawn_for_all_world_sizes(run_test_parallel_embedding)
def test_column_parallel():
    """Run ``run_test_column_parallel_linear`` through ``spawn_for_all_world_sizes``."""
    spawn_for_all_world_sizes(run_test_column_parallel_linear)
def test_row_parallel():
    """Run ``run_test_row_parallel_linear`` through ``spawn_for_all_world_sizes``."""
    spawn_for_all_world_sizes(run_test_row_parallel_linear)
def test_pipe():
    """Run ``run_test_pipe`` for the world sizes that fit in half of the CUDA devices."""
    half_of_devices = torch.cuda.device_count() / 2
    world_sizes = [size for size in get_world_sizes() if size <= half_of_devices]
    spawn_for_all_world_sizes(run_test_pipe, world_sizes)
tests/nn/model_parallel/test_random.py
0 → 100644
View file @
30f5009a
# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
from
fairscale.nn.model_parallel
import
initialize
as
mpu
from
fairscale.nn.model_parallel
import
random
from
fairscale.nn.model_parallel.random
import
get_cuda_rng_tracker
,
model_parallel_cuda_manual_seed
from
tests.nn.model_parallel.commons
import
dist_init
,
spawn_for_all_world_sizes
def run_test_set_cuda_rng_state(rank, model_parallel_size):
    """Check that ``random._set_cuda_rng_state`` restores the CUDA RNG without mutating its argument.

    After restoring a captured state, the same sequence of ``torch.randn``
    calls must reproduce the earlier result.
    """
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing set_rng_state with size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    size = 123
    seed = 1234
    torch.cuda.manual_seed(seed)
    tensor = torch.cuda.FloatTensor(size)

    def draw_samples():
        # Five consecutive draws written into the same output tensor.
        for _ in range(5):
            torch.randn(size, out=tensor)

    # Get the state
    rng_state = torch.cuda.get_rng_state()
    rng_state_copy = rng_state.clone()

    # Do some stuff.
    draw_samples()
    result_1 = tensor.clone()

    assert rng_state.sub(rng_state_copy).max() == 0
    assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0

    # State should be different.
    new_rng_state = torch.cuda.get_rng_state()
    max_diff = new_rng_state.sub(rng_state).max()
    print(
        " max diff in rng state (should be non-zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), max_diff
        )
    )
    assert max_diff > 0

    # Reset the rng state and do the same stuff (twice, as two reset/draw rounds).
    for _ in range(2):
        random._set_cuda_rng_state(rng_state)
        draw_samples()
    result_2 = tensor.clone()

    # Results should be the same
    error = result_2.sub(result_1).abs().max()
    print(
        " max error in generated tensors (should be zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), error
        )
    )
    assert error < 1.0e-6

    # Input state should have remained intact.
    error = rng_state.sub(rng_state_copy).max()
    print(
        " max error in rng state (should be zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), error
        )
    )
    assert error == 0

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")
def run_test_cuda_rng_tracker(rank, model_parallel_size):
    """Check that draws inside and outside a tracker ``fork`` follow two independent seeded streams.

    Reference tensors are drawn with each seed on its own; the interleaved
    draws (default stream vs. the forked "test" stream) must reproduce them.
    """
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing cuda rng tracker with size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed_1 = 1234
    seed_2 = 4321
    size = [12, 21]
    tensor = torch.cuda.FloatTensor(size)

    def sample():
        # Draw into the shared output tensor and return an independent copy.
        torch.randn(size, out=tensor)
        return tensor.clone()

    # Set to seed_1 and generate two tensors.
    torch.cuda.manual_seed(seed_1)
    target_11 = sample()
    target_12 = sample()

    # Set to seed_2 and generate two tensors.
    torch.cuda.manual_seed(seed_2)
    target_21 = sample()
    target_22 = sample()

    # Now if we interleave seed_1 and seed_2,
    # we should still get the same tensors
    torch.cuda.manual_seed(seed_1)
    get_cuda_rng_tracker().add("test", seed_2)

    result_11 = sample()

    with get_cuda_rng_tracker().fork("test"):
        result_21 = sample()

    result_12 = sample()

    with get_cuda_rng_tracker().fork("test"):
        result_22 = sample()

    first_gap = result_11.sub(result_21).abs().max()
    second_gap = result_12.sub(result_22).abs().max()
    diff = min(first_gap, second_gap)
    print(
        " max diff in generated tensors (should be non-zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), diff
        )
    )
    assert diff > 1.0e-6

    matched_pairs = (
        (result_11, target_11),
        (result_12, target_12),
        (result_21, target_21),
        (result_22, target_22),
    )
    error = max(result.sub(target).abs().max() for result, target in matched_pairs)
    print(
        " max error in generated tensors (should be zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), error
        )
    )
    assert error < 1.0e-6

    # Reset the tracker
    get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")
def run_test_model_parallel_cuda_manual_seed(rank, model_parallel_size):
    """Check the CUDA seeds in effect after ``model_parallel_cuda_manual_seed(12345)``.

    Outside the tracker fork the initial seed must be 12345; inside it must
    be 12345 + 2718 + the model-parallel rank.
    """
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing model parallel cuda manual seed with size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    model_parallel_cuda_manual_seed(12345)
    assert torch.cuda.initial_seed() == 12345
    with get_cuda_rng_tracker().fork():
        expected_forked_seed = 12345 + 2718 + mpu.get_model_parallel_rank()
        assert torch.cuda.initial_seed() == expected_forked_seed

    # Reset the tracker
    get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")
def test_set_cuda_rng_state():
    """Run ``run_test_set_cuda_rng_state`` through ``spawn_for_all_world_sizes``."""
    spawn_for_all_world_sizes(run_test_set_cuda_rng_state)
def test_cuda_rng_tracker():
    """Run ``run_test_cuda_rng_tracker`` through ``spawn_for_all_world_sizes``."""
    spawn_for_all_world_sizes(run_test_cuda_rng_tracker)
def test_model_parallel_cuda_manual_seed():
    """Run ``run_test_model_parallel_cuda_manual_seed`` through ``spawn_for_all_world_sizes``."""
    spawn_for_all_world_sizes(run_test_model_parallel_cuda_manual_seed)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment