OpenDAS / fairscale, commit 30f5009a

[feat] Model parallel (#3)

Authored Jul 22, 2020 by Tom Birch; committed Jul 31, 2020 by Mandeep Singh Baines.
Parent: 8634280c
Changes: 26 in total; this page (1 of 2) shows 6 changed files with 889 additions and 0 deletions (+889 -0):

  stubs/torch/utils/checkpoint.pyi                +6   -0
  tests/nn/model_parallel/commons.py              +64  -0
  tests/nn/model_parallel/test_cross_entropy.py   +88  -0
  tests/nn/model_parallel/test_initialize.py      +136 -0
  tests/nn/model_parallel/test_layers.py          +386 -0
  tests/nn/model_parallel/test_random.py          +209 -0
stubs/torch/utils/checkpoint.pyi (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Tuple
from .. import Tensor
def detach_variable(inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: ...
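The stub types torch.utils.checkpoint.detach_variable, which activation checkpointing uses to replay a forward pass during backward. As a hedged illustration (not part of this commit; behavior as I understand the upstream torch helper), the function returns detached tensors that preserve requires_grad:

import torch
from torch.utils.checkpoint import detach_variable

x = torch.randn(3, requires_grad=True)
(y,) = detach_variable((x * 2,))
# Detached from the original graph, but still a grad-requiring leaf,
# so a checkpointed forward can be re-run against it in backward.
assert y.requires_grad and y.grad_fn is None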
tests/nn/model_parallel/commons.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random

import numpy
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed


class IdentityLayer(torch.nn.Module):
    def __init__(self, size, scale=1.0):
        super(IdentityLayer, self).__init__()
        self.weight = torch.nn.Parameter(scale * torch.randn(size))

    def forward(self):
        return self.weight


def set_random_seed(seed):
    """Set random seed for reproducibility."""
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    model_parallel_cuda_manual_seed(seed)


def dist_init(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def get_world_sizes():
    limit = torch.cuda.device_count()
    return [x for x in [1, 2, 4, 8] if x <= limit]


def spawn_for_all_world_sizes(test_func, world_sizes=get_world_sizes()):
    for world_size in world_sizes:
        mp.spawn(test_func, args=(world_size,), nprocs=world_size, join=True)
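A minimal sketch (not part of the commit) of how these helpers compose in a test: torch.multiprocessing.spawn passes the worker's rank as the first argument, so each test body takes (rank, world_size), calls dist_init, and only then touches model parallel state. This assumes at least one CUDA device, since dist_init uses the NCCL backend.

import torch
from tests.nn.model_parallel.commons import dist_init, spawn_for_all_world_sizes

def run_my_check(rank, world_size):
    dist_init(rank, world_size)  # NCCL init plus binding this rank to its GPU
    print("rank {} of {} is up".format(rank, world_size))
    torch.distributed.barrier()

def test_my_check():
    # Runs once per world size in [1, 2, 4, 8], capped by available GPUs.
    spawn_for_all_world_sizes(run_my_check)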
tests/nn/model_parallel/test_cross_entropy.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F

from fairscale.nn.model_parallel import initialize as mpu
from fairscale.nn.model_parallel.cross_entropy import vocab_parallel_cross_entropy
from fairscale.nn.model_parallel.mappings import scatter_to_model_parallel_region
from tests.nn.model_parallel.commons import IdentityLayer, dist_init, set_random_seed, spawn_for_all_world_sizes


def torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
    set_random_seed(seed)
    identity = IdentityLayer((batch_size, seq_length, vocab_size), scale=logits_scale).cuda()
    logits = identity()
    target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size)
    loss = F.cross_entropy(logits.view(-1, logits.size()[-1]), target.view(-1), reduction="none").view_as(target).mean()
    loss.backward()
    return loss, identity.weight.grad


def mpu_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
    set_random_seed(seed)
    identity = IdentityLayer((batch_size, seq_length, vocab_size), scale=logits_scale).cuda()
    logits = identity()
    logits_parallel = scatter_to_model_parallel_region(logits)
    target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size)
    loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
    loss.backward()
    return loss, identity.weight.grad


def run_test_cross_entropy(rank, model_parallel_size):
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing cross entropy with model parallel size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    batch_size = 13
    seq_length = 17
    vocab_size_per_partition = 11
    logits_scale = 1000.0
    vocab_size = vocab_size_per_partition * model_parallel_size
    seed = 1234

    loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed)
    loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed)

    error = loss_torch.sub_(loss_mpu).abs().max()
    print(" max error in loss on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = grad_torch.sub_(grad_mpu).abs().max()
    print(" max error in grad on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")


def test_cross_entropy():
    spawn_for_all_world_sizes(run_test_cross_entropy)
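The test leans on the fact that softmax cross entropy is computable from vocab-sharded logits: log-softmax needs only a global max and a global sum of exponentials over the vocabulary dimension, and both reduce across partitions. A single-process sketch of that identity (illustrative only; fairscale's implementation performs these reductions with collectives over the model parallel group):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)               # (batch, vocab)
target = torch.randint(0, 10, (4,))
shards = torch.split(logits, 5, dim=-1)   # two simulated vocab partitions

# Reduce a global max and a global sum of exps across the shards.
global_max = torch.stack([s.max(dim=-1).values for s in shards]).max(dim=0).values
sum_exp = sum((s - global_max[:, None]).exp().sum(dim=-1) for s in shards)

# Log-softmax of the target logit from the reduced quantities.
target_logit = logits.gather(1, target[:, None]).squeeze(1)
loss = -(target_logit - global_max - sum_exp.log()).mean()

assert torch.allclose(loss, F.cross_entropy(logits, target))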
tests/nn/model_parallel/test_initialize.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from fairscale.nn.model_parallel import initialize as mpu
from tests.nn.model_parallel.commons import dist_init, spawn_for_all_world_sizes


def run_test_initialize_model_parallel(rank, model_parallel_size):
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing initialize_model_parallel with size {} ...".format(model_parallel_size))
    model_parallel_size_ = min(model_parallel_size, torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size_)
    assert mpu.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = model_parallel_size_
    rank = torch.distributed.get_rank() % model_parallel_size_
    assert world_size == mpu.get_model_parallel_world_size()
    assert rank == mpu.get_model_parallel_rank()
    check(mpu.get_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = torch.distributed.get_world_size() // model_parallel_size_
    rank = torch.distributed.get_rank() // model_parallel_size
    assert world_size == mpu.get_data_parallel_world_size()
    assert rank == mpu.get_data_parallel_rank()
    check(mpu.get_data_parallel_group(), world_size, rank)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")


def run_test_get_model_parallel_src_rank(rank, model_parallel_size_):
    dist_init(rank, model_parallel_size_)

    if torch.distributed.get_rank() == 0:
        print("> testing get_model_parallel_src_rank with size {} ...".format(model_parallel_size_))
    model_parallel_size = min(model_parallel_size_, torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size)
    assert mpu.model_parallel_is_initialized()

    # Checks
    src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank()
    assert mpu.get_model_parallel_src_rank() == src_rank

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")


def test_initialize_model_parallel():
    spawn_for_all_world_sizes(run_test_initialize_model_parallel)


def test_get_model_parallel_src_rank():
    spawn_for_all_world_sizes(run_test_get_model_parallel_src_rank)


def test_adjacency(monkeypatch):
    new_groups = []

    data_parallel_size = 32
    pipeline_length = 8
    model_parallel_size = 4

    class MockDistributed:
        def get_rank(self):
            return 0

        def is_initialized(self):
            return True

        def get_world_size(self):
            return data_parallel_size * pipeline_length * model_parallel_size

        def new_group(self, args):
            new_groups.append(args.copy())
            return ()

    monkeypatch.setattr(torch, "distributed", MockDistributed())

    mpu.initialize_model_parallel(model_parallel_size, pipeline_length)

    from collections import defaultdict

    buckets = defaultdict(list)
    for group in new_groups:
        buckets[len(group)].append(group)

    assert sorted(list(buckets.keys())) == [model_parallel_size, data_parallel_size]
    assert len(buckets[model_parallel_size]) == pipeline_length * data_parallel_size
    assert len(buckets[data_parallel_size]) == model_parallel_size * pipeline_length

    # Check that model_parallel groups are contiguous
    for group in buckets[model_parallel_size]:
        assert sorted(group) == group
        assert list(range(group[0], group[-1] + 1)) == group
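The rank arithmetic asserted above follows the Megatron-style layout: adjacent ranks form model parallel groups, and data parallel groups stride across them. A plain-Python sketch of the mapping for world_size = 8, model_parallel_size = 2, mirroring the % and // assertions in run_test_initialize_model_parallel:

world_size, model_parallel_size = 8, 2
model_groups = [list(range(i, i + model_parallel_size))
                for i in range(0, world_size, model_parallel_size)]
data_groups = [list(range(i, world_size, model_parallel_size))
               for i in range(model_parallel_size)]
assert model_groups == [[0, 1], [2, 3], [4, 5], [6, 7]]  # contiguous, per test_adjacency
assert data_groups == [[0, 2, 4, 6], [1, 3, 5, 7]]
# e.g. rank 5 has model parallel rank 5 % 2 == 1 and data parallel rank 5 // 2 == 2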
tests/nn/model_parallel/test_layers.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
import torch.nn.init as init
from torch.nn.parameter import Parameter

from fairscale.nn.model_parallel import initialize as mpu
from fairscale.nn.model_parallel import layers
from fairscale.nn.pipe import Pipe
from tests.nn.model_parallel.commons import dist_init, get_world_sizes, set_random_seed, spawn_for_all_world_sizes


def run_test_parallel_embedding(rank, model_parallel_size):
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing parallel embedding with model parallel size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    batch_size = 17
    seq_length = 23
    vocab_size = 48
    hidden_size = 16
    seed = 1236

    set_random_seed(123)
    input_data = torch.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
    loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

    set_random_seed(seed)
    embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()

    output = embedding_original(input_data)
    loss_original = torch.mul(output, loss_weight).sum()
    loss_original.backward()

    set_random_seed(seed)
    embedding_parallel = layers.ParallelEmbedding(vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_parallel(input_data)
    loss_parallel = torch.mul(output, loss_weight).sum()
    loss_parallel.backward()

    set_random_seed(seed)
    embedding_vocab_parallel = layers.VocabParallelEmbedding(vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_vocab_parallel(input_data)
    loss_vocab_parallel = torch.mul(output, loss_weight).sum()
    loss_vocab_parallel.backward()

    torch.distributed.barrier()
    error = loss_parallel.sub(loss_original).abs()
    print(" error in loss (parallel) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-12, "error: {}".format(error)

    torch.distributed.barrier()
    error = loss_vocab_parallel.sub(loss_original).abs()
    print(" error in loss (vocab parallel) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-12, "error: {}".format(error)

    weight_grad_orig = torch.split(embedding_original.weight.grad, hidden_size // model_parallel_size, 1)[
        mpu.get_model_parallel_rank()
    ]
    error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
    print(" error in grad (parallel) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-12, "error: {}".format(error)

    weight_grad_orig = torch.split(embedding_original.weight.grad, vocab_size // model_parallel_size, 0)[
        mpu.get_model_parallel_rank()
    ]
    error = embedding_vocab_parallel.weight.grad.sub(weight_grad_orig).abs().max()
    print(" error in grad (vocab parallel) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-12, "error: {}".format(error)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")


def run_test_initialize_affine_weight(rank, model_parallel_size):
    dist_init(rank, model_parallel_size)
    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print("> testing initialize_affine_weight with model parallel size: {}".format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size

    # ---------------
    # Column parallel
    # ---------------
    weight = torch.empty(output_size_coeff, input_size)
    set_random_seed(seed)
    layers._initialize_affine_weight(weight, output_size, input_size, output_size_coeff, 0, torch.nn.init.normal_)
    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = mpu.get_model_parallel_rank()
    my_weight = torch.split(master_weight, output_size_coeff, dim=0)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(" column parallel max error (should be zero) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # ------------
    # Row parallel
    # ------------
    weight = torch.empty(output_size, input_size_coeff)
    set_random_seed(seed)
    layers._initialize_affine_weight(weight, output_size, input_size, input_size_coeff, 1, torch.nn.init.normal_)
    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = mpu.get_model_parallel_rank()
    my_weight = torch.split(master_weight, input_size_coeff, dim=1)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(" row parallel max error (should be zero) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(" >> passed the test :-)")


class IdentityLayer2D(torch.nn.Module):
    def __init__(self, m, n):
        super(IdentityLayer2D, self).__init__()
        self.weight = Parameter(torch.Tensor(m, n))
        torch.nn.init.xavier_normal_(self.weight)

    def forward(self):
        return self.weight


def run_test_column_parallel_linear(rank, model_parallel_size):
    dist_init(rank, model_parallel_size)
    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print("> testing ColumnParallelLinear with model parallel size: {}".format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = layers.ColumnParallelLinear(input_size, output_size, keep_master_weight_for_test=True).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = mpu.get_model_parallel_rank()
    my_dLdA = torch.split(dLdA, output_size_coeff, dim=0)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(" error in dLdA on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    my_dLdb = torch.split(dLdb, output_size_coeff, dim=0)[rank].contiguous().clone()
    error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print(" error in dLdb on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(" error in dLdX on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(" >> passed the test :-)")


def run_test_row_parallel_linear(rank, model_parallel_size):
    dist_init(rank, model_parallel_size)
    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print("> testing RowParallelLinear with model parallel size: {}".format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = layers.RowParallelLinear(input_size, output_size, keep_master_weight_for_test=True).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = mpu.get_model_parallel_rank()
    my_dLdA = torch.split(dLdA, input_size_coeff, dim=1)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(" error in dLdA on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print(" error in dLdb on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(" error in dLdX on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(" >> passed the test :-)")


def run_test_pipe(rank, model_parallel_size):
    pipe_world_size = 2

    dist_init(rank, model_parallel_size)
    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print(
            "> testing Sequential + Pipe with model parallel size: {}, pipe: {}".format(
                model_parallel_size, pipe_world_size
            )
        )
    model_parallel_size = mpu.get_model_parallel_world_size()

    chunk_size = 8

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size
    batch_size = 7 * chunk_size

    identity = IdentityLayer2D(batch_size, input_size).cuda()

    pipeline_devices = mpu.get_pipeline_parallel_group()
    if pipe_world_size == 2 and len(pipeline_devices) == 1:
        pipeline_devices.append(pipeline_devices[0] + model_parallel_size)

    set_random_seed(seed)
    model = nn.Sequential(
        layers.ColumnParallelLinear(input_size, output_size, keep_master_weight_for_test=True, bias=False).cuda(),
        nn.ReLU(),
        layers.RowParallelLinear(output_size, input_size, keep_master_weight_for_test=True, bias=False).cuda(),
    )
    set_random_seed(seed)

    reference = nn.Sequential(
        nn.Linear(input_size, output_size, bias=False).cuda(),
        nn.ReLU(),
        nn.Linear(output_size, input_size, bias=False).cuda(),
    )

    reference[0].weight.data = model[0].master_weight.cuda()
    reference[-1].weight.data = model[-1].master_weight.cuda()

    loss_weight = torch.randn([batch_size, output_size]).cuda()
    output = model(identity())
    reference_output = reference(identity())

    error = reference_output.sub(output).max()
    torch.distributed.barrier()
    assert error < 1.0e-6

    if pipe_world_size == 2:
        pipe_model = Pipe(model, [2, 1], devices=pipeline_devices, chunks=chunk_size)
        torch.distributed.barrier()
        pipe_output = pipe_model(identity())
        error = reference_output.sub(pipe_output.cuda()).max()
        torch.distributed.barrier()
        assert error < 1.0e-6


torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def test_affine_weight():
    spawn_for_all_world_sizes(run_test_initialize_affine_weight)


def test_embedding():
    spawn_for_all_world_sizes(run_test_parallel_embedding)


def test_column_parallel():
    spawn_for_all_world_sizes(run_test_column_parallel_linear)


def test_row_parallel():
    spawn_for_all_world_sizes(run_test_row_parallel_linear)


def test_pipe():
    world_sizes = [x for x in get_world_sizes() if x <= torch.cuda.device_count() / 2]
    spawn_for_all_world_sizes(run_test_pipe, world_sizes)
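The linear-layer tests compare autograd against closed-form gradients: with Y = X Aᵀ + b and loss = sum(W ∘ Y), dL/dY = W, so dL/dA = Wᵀ X, dL/db sums W over the batch, and dL/dX = W A; each rank then checks only its split of dL/dA (rows for column parallel, columns for row parallel). A single-process sketch verifying the same identities against autograd (illustrative, CPU-only):

import torch

x = torch.randn(7, 13, requires_grad=True)
A = torch.randn(17, 13, requires_grad=True)
b = torch.randn(17, requires_grad=True)
w = torch.randn(7, 17)  # plays the role of loss_weight

loss = (w * (x @ A.t() + b)).sum()
loss.backward()

assert torch.allclose(A.grad, w.t() @ x, atol=1e-5)    # dLdA = dLdY.t() @ X
assert torch.allclose(b.grad, w.sum(dim=0), atol=1e-5)  # dLdb = ones.t() @ dLdY
assert torch.allclose(x.grad, w @ A, atol=1e-5)         # dLdX = dLdY @ A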
tests/nn/model_parallel/test_random.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from fairscale.nn.model_parallel import initialize as mpu
from fairscale.nn.model_parallel import random
from fairscale.nn.model_parallel.random import get_cuda_rng_tracker, model_parallel_cuda_manual_seed
from tests.nn.model_parallel.commons import dist_init, spawn_for_all_world_sizes


def run_test_set_cuda_rng_state(rank, model_parallel_size):
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing set_rng_state with size {} ...".format(model_parallel_size))
    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    size = 123
    seed = 1234
    torch.cuda.manual_seed(1234)
    tensor = torch.cuda.FloatTensor(size)

    # Get the state
    rng_state = torch.cuda.get_rng_state()
    rng_state_copy = rng_state.clone()

    # Do some stuff.
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_1 = tensor.clone()

    assert rng_state.sub(rng_state_copy).max() == 0
    assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0

    # State should be different.
    new_rng_state = torch.cuda.get_rng_state()
    max_diff = new_rng_state.sub(rng_state).max()
    print(" max diff in rng state (should be non-zero) on global rank {}: {}".format(torch.distributed.get_rank(), max_diff))
    assert max_diff > 0

    # Reset the rng state and do the same stuff.
    random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_2 = tensor.clone()

    # Results should be the same
    error = result_2.sub(result_1).abs().max()
    print(" max error in generated tensors (should be zero) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Input state should have remained intact.
    error = rng_state.sub(rng_state_copy).max()
    print(" max error in rng state (should be zero) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error == 0

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")


def run_test_cuda_rng_tracker(rank, model_parallel_size):
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing cuda rng tracker with size {} ...".format(model_parallel_size))
    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed_1 = 1234
    seed_2 = 4321
    size = [12, 21]
    tensor = torch.cuda.FloatTensor(size)

    # Set to seed_1 and generate two tensors.
    torch.cuda.manual_seed(seed_1)
    torch.randn(size, out=tensor)
    target_11 = tensor.clone()
    torch.randn(size, out=tensor)
    target_12 = tensor.clone()

    # Set to seed_2 and generate two tensors.
    torch.cuda.manual_seed(seed_2)
    torch.randn(size, out=tensor)
    target_21 = tensor.clone()
    torch.randn(size, out=tensor)
    target_22 = tensor.clone()

    # Now if we interleave seed_1 and seed_2,
    # we should still get the same tensors
    torch.cuda.manual_seed(seed_1)
    get_cuda_rng_tracker().add("test", seed_2)

    torch.randn(size, out=tensor)
    result_11 = tensor.clone()

    with get_cuda_rng_tracker().fork("test"):
        torch.randn(size, out=tensor)
        result_21 = tensor.clone()

    torch.randn(size, out=tensor)
    result_12 = tensor.clone()

    with get_cuda_rng_tracker().fork("test"):
        torch.randn(size, out=tensor)
        result_22 = tensor.clone()

    diff = result_11.sub(result_21).abs().max()
    diff = min(diff, result_12.sub(result_22).abs().max())
    print(" max diff in generated tensors (should be non-zero) on global rank {}: {}".format(torch.distributed.get_rank(), diff))
    assert diff > 1.0e-6
    error = max(result_11.sub(target_11).abs().max(), result_12.sub(target_12).abs().max())
    error = max(error, result_21.sub(target_21).abs().max())
    error = max(error, result_22.sub(target_22).abs().max())
    print(" max error in generated tensors (should be zero) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset the tracker
    get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")


def run_test_model_parallel_cuda_manual_seed(rank, model_parallel_size):
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing model parallel cuda manual seed with size {} ...".format(model_parallel_size))
    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    model_parallel_cuda_manual_seed(12345)
    assert torch.cuda.initial_seed() == 12345
    with get_cuda_rng_tracker().fork():
        assert torch.cuda.initial_seed() == (12345 + 2718 + mpu.get_model_parallel_rank())

    # Reset the tracker
    get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")


def test_set_cuda_rng_state():
    spawn_for_all_world_sizes(run_test_set_cuda_rng_state)


def test_cuda_rng_tracker():
    spawn_for_all_world_sizes(run_test_cuda_rng_tracker)


def test_model_parallel_cuda_manual_seed():
    spawn_for_all_world_sizes(run_test_model_parallel_cuda_manual_seed)
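The intended usage pattern, judging from these tests: all model parallel ranks share the default CUDA RNG state (so replicated tensors match across ranks), and code forks to a per-rank tracked state wherever randomness must differ across partitions, e.g. dropout inside a sharded block. A hedged sketch of that pattern (parallel_dropout is a hypothetical helper, not part of this commit):

import torch
from fairscale.nn.model_parallel.random import get_cuda_rng_tracker

def parallel_dropout(x, p=0.1):
    # Inside fork(), draws come from the per-model-parallel-rank stream;
    # the default stream is restored untouched on exit.
    with get_cuda_rng_tracker().fork():
        return torch.nn.functional.dropout(x, p=p)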