Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
98223763
Unverified
Commit
98223763
authored
Mar 17, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Mar 17, 2021
Browse files
[refactor] removing duplicated tests (#529)
parent
39a12a8b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
1 addition
and
351 deletions
+1
-351
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
+1
-351
No files found.
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
View file @
98223763
...
@@ -23,15 +23,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
...
@@ -23,15 +23,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from
fairscale.nn.data_parallel
import
ShardedDataParallel
from
fairscale.nn.data_parallel
import
ShardedDataParallel
from
fairscale.optim
import
OSS
from
fairscale.optim
import
OSS
from
fairscale.optim.grad_scaler
import
ShardedGradScaler
from
fairscale.optim.grad_scaler
import
ShardedGradScaler
from
fairscale.utils.testing
import
(
from
fairscale.utils.testing
import
check_same_model_params
,
skip_if_no_cuda
,
skip_if_single_gpu
GPT2
,
available_devices
,
check_same_model_params
,
check_same_models_across_ranks
,
skip_if_less_than_four_gpu
,
skip_if_no_cuda
,
skip_if_single_gpu
,
)
"""
"""
Check that ShardedDDP gets the same results as DDP in a variety of scenarii
Check that ShardedDDP gets the same results as DDP in a variety of scenarii
...
@@ -47,17 +39,6 @@ def _get_mlp():
...
@@ -47,17 +39,6 @@ def _get_mlp():
return
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
))
return
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
))
class
_DoubleInput
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
mlp
=
_get_mlp
()
def
forward
(
self
,
x
,
y
):
x1
=
self
.
mlp
(
x
)
x2
=
self
.
mlp
(
y
)
return
torch
.
cat
((
x1
,
x2
),
dim
=
1
)
def
run_ddp_parity
(
def
run_ddp_parity
(
rank
,
world_size
,
backend
,
temp_file_name
,
reduce_buffer_size
,
grad_accumulation
,
change_train_graph
,
fp16_reduction
rank
,
world_size
,
backend
,
temp_file_name
,
reduce_buffer_size
,
grad_accumulation
,
change_train_graph
,
fp16_reduction
):
):
...
@@ -297,334 +278,3 @@ def test_ddp_parity_two_optim(reduce_buffer_size):
...
@@ -297,334 +278,3 @@ def test_ddp_parity_two_optim(reduce_buffer_size):
nprocs
=
world_size
,
nprocs
=
world_size
,
join
=
True
,
join
=
True
,
)
)
def
run_test_two_inputs
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
,
reduce_buffer_size
):
dist
.
init_process_group
(
init_method
=
"file://"
+
temp_file_name
,
backend
=
backend
,
rank
=
rank
,
world_size
=
world_size
)
if
device
==
"cuda"
:
torch
.
cuda
.
set_device
(
rank
)
torch
.
manual_seed
(
rank
)
np
.
random
.
seed
(
rank
)
model
=
_DoubleInput
().
to
(
device
)
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
ddp_model
=
ShardedDataParallel
(
model
,
optimizer
,
reduce_buffer_size
=
reduce_buffer_size
)
# Optim loop
def
closure
():
optimizer
.
zero_grad
()
input_tensor
=
torch
.
rand
((
64
,
2
)).
to
(
device
)
loss
=
ddp_model
(
input_tensor
,
input_tensor
).
abs
().
sum
()
loss
.
backward
()
return
loss
for
i
in
range
(
5
):
_
=
optimizer
.
step
(
closure
=
closure
)
dist
.
destroy_process_group
()
@
pytest
.
mark
.
parametrize
(
"reduce_buffer_size"
,
[
0
,
2
**
20
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"gloo"
,
"nccl"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
available_devices
)
def
test_inputs
(
reduce_buffer_size
,
backend
,
device
):
# Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
world_size
=
2
if
backend
==
"nccl"
and
device
==
"cpu"
:
pytest
.
skip
(
"Incompatible combination, or cuda not available"
)
return
mp
.
spawn
(
run_test_two_inputs
,
args
=
(
world_size
,
backend
,
device
,
tempfile
.
mkstemp
()[
1
],
reduce_buffer_size
),
nprocs
=
world_size
,
join
=
True
,
)
def
test_ddp_attributes
():
# Check that ShardedDDP exposes the same attributes as Pytorch's DDP
# - is multi_device_module
# - device_type
dist
.
init_process_group
(
init_method
=
"file://"
+
tempfile
.
mkstemp
()[
1
],
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
model
=
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
))
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
ddp_model
=
ShardedDataParallel
(
model
,
optimizer
)
assert
hasattr
(
ddp_model
,
"is_multi_device_module"
)
assert
hasattr
(
ddp_model
,
"device_type"
)
dist
.
destroy_process_group
()
def
test_random_attributes
():
# Check that ShardedDDP exposes the original module's attributes
dist
.
init_process_group
(
init_method
=
"file://"
+
tempfile
.
mkstemp
()[
1
],
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
model
=
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
))
model
.
banana
=
"sweet"
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
ddp_model
=
ShardedDataParallel
(
model
,
optimizer
)
assert
hasattr
(
ddp_model
,
"banana"
)
assert
not
hasattr
(
ddp_model
,
"orange"
)
dist
.
destroy_process_group
()
def
run_test_device_change
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
,
reduce_buffer_size
):
# Check that the wrapped module can change devices
dist
.
init_process_group
(
init_method
=
"file://"
+
temp_file_name
,
backend
=
backend
,
rank
=
rank
,
world_size
=
world_size
)
torch
.
cuda
.
set_device
(
rank
)
model
=
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
)).
cpu
()
# not device on purpose, test changing it after the fact
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
ddp_model
=
ShardedDataParallel
(
model
,
optimizer
,
sync_models_at_startup
=
False
,
reduce_buffer_size
=
reduce_buffer_size
)
try
:
ddp_model
.
to
(
device
)
assert
False
,
"Changing devices should be caught and not supported"
except
AssertionError
:
pass
dist
.
destroy_process_group
()
@
skip_if_no_cuda
@
skip_if_single_gpu
@
pytest
.
mark
.
parametrize
(
"reduce_buffer_size"
,
[
0
,
2
**
20
])
def
test_device_change
(
reduce_buffer_size
):
# Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
world_size
=
2
backend
=
"nccl"
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
device
=
"cuda"
mp
.
spawn
(
run_test_device_change
,
args
=
(
world_size
,
backend
,
device
,
temp_file_name
,
reduce_buffer_size
),
nprocs
=
world_size
,
join
=
True
,
)
def
run_test_training_change
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
,
reduce_buffer_size
):
group
=
dist
.
init_process_group
(
init_method
=
"file://"
+
temp_file_name
,
backend
=
backend
,
rank
=
rank
,
world_size
=
world_size
)
torch
.
cuda
.
set_device
(
rank
)
model
=
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
)).
to
(
device
)
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
ddp_model
=
ShardedDataParallel
(
model
,
optimizer
,
process_group
=
group
,
reduce_buffer_size
=
reduce_buffer_size
)
inputs
=
torch
.
rand
((
10
,
2
),
device
=
device
)
outputs
=
ddp_model
(
inputs
)
# assert if the module has not been changed properly
_
=
outputs
.
norm
().
backward
()
ddp_model
.
eval
()
ddp_model
(
inputs
)
# This will assert if eval() is not properly taken into account
ddp_model
(
inputs
)
dist
.
destroy_process_group
()
@
skip_if_no_cuda
@
skip_if_single_gpu
@
pytest
.
mark
.
parametrize
(
"reduce_buffer_size"
,
[
0
,
2
**
20
])
def
test_training_change
(
reduce_buffer_size
):
world_size
=
2
backend
=
"nccl"
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
device
=
"cuda"
mp
.
spawn
(
run_test_training_change
,
args
=
(
world_size
,
backend
,
device
,
temp_file_name
,
reduce_buffer_size
),
nprocs
=
world_size
,
join
=
True
,
)
def
run_test_ddp_sync_batch_norm
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
):
dist
.
init_process_group
(
init_method
=
"file://"
+
temp_file_name
,
backend
=
backend
,
rank
=
rank
,
world_size
=
world_size
)
model
=
Sequential
(
Linear
(
2
,
3
),
torch
.
nn
.
BatchNorm1d
(
3
),
Linear
(
3
,
3
)).
to
(
device
)
model
=
torch
.
nn
.
SyncBatchNorm
.
convert_sync_batchnorm
(
model
)
model
.
to
(
device
)
# in pytorch 1.5 syncBN switches to the default device/cpu
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
ddp_model
=
ShardedDataParallel
(
model
,
optimizer
)
assert
isinstance
(
model
[
1
],
torch
.
nn
.
SyncBatchNorm
)
# Ensures sync batch norm handles have been added
ddp_model
(
torch
.
randn
(
2
,
2
).
to
(
device
))
dist
.
destroy_process_group
()
@
skip_if_no_cuda
@
skip_if_single_gpu
def
test_ddp_sync_batch_norm
():
# Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
world_size
=
2
backend
=
"gloo"
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
device
=
"cuda"
mp
.
spawn
(
run_test_ddp_sync_batch_norm
,
args
=
(
world_size
,
backend
,
device
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
def
run_test_two_optimizers
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
):
dist
.
init_process_group
(
init_method
=
"file://"
+
temp_file_name
,
backend
=
backend
,
rank
=
rank
,
world_size
=
world_size
)
if
device
==
torch
.
device
(
"cuda"
):
torch
.
cuda
.
set_device
(
rank
)
torch
.
manual_seed
(
rank
)
np
.
random
.
seed
(
rank
)
model
=
_DoubleInput
().
to
(
device
)
parameters
=
list
(
model
.
parameters
())
optimizer_1
=
OSS
(
params
=
parameters
[:
-
10
],
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
optimizer_2
=
OSS
(
params
=
parameters
[
-
10
:],
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
ddp_model
=
ShardedDataParallel
(
model
,
[
optimizer_1
,
optimizer_2
])
# Optim loop
def
closure
():
input_tensor
=
torch
.
rand
((
64
,
2
)).
to
(
device
)
loss
=
ddp_model
(
input_tensor
,
input_tensor
).
abs
().
sum
()
loss
.
backward
()
return
loss
for
i
in
range
(
5
):
optimizer_1
.
zero_grad
()
optimizer_2
.
zero_grad
()
_
=
optimizer_1
.
step
(
closure
=
closure
)
_
=
optimizer_2
.
step
(
closure
=
closure
)
dist
.
destroy_process_group
()
def
test_two_optimizers
():
# Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
world_size
=
2
backend
=
"gloo"
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
device
=
"cpu"
mp
.
spawn
(
run_test_two_optimizers
,
args
=
(
world_size
,
backend
,
device
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
def
run_test_gpt2
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
):
INPUT_DIM
=
16
BACH_SIZE
=
10
STEPS
=
10
url
=
"file://"
+
temp_file_name
dist
.
init_process_group
(
init_method
=
url
,
backend
=
backend
,
rank
=
rank
,
world_size
=
world_size
)
torch
.
cuda
.
set_device
(
rank
)
torch
.
manual_seed
(
rank
)
np
.
random
.
seed
(
rank
)
model
=
GPT2
(
embed_dim
=
256
,
num_heads
=
2
,
num_layers
=
12
,
num_positions
=
INPUT_DIM
*
INPUT_DIM
,
num_vocab
=
512
,
num_classes
=
2
).
to
(
device
)
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
ddp_model
=
ShardedDataParallel
(
model
,
optimizer
,
reduce_buffer_size
=
0
)
# Optim loop
def
closure
():
optimizer
.
zero_grad
()
# Force int inputs to prevent the first grad from firing
input_tensor
=
torch
.
randint
(
10
,
(
BACH_SIZE
,
INPUT_DIM
)).
to
(
device
)
loss
=
ddp_model
(
input_tensor
).
abs
().
sum
()
loss
.
backward
()
return
loss
# Check for bucketing overflows
for
i
in
range
(
STEPS
):
_
=
optimizer
.
step
(
closure
=
closure
)
dist
.
destroy_process_group
()
@
skip_if_no_cuda
@
skip_if_single_gpu
def
test_gpt2
():
# Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
world_size
=
2
backend
=
"gloo"
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
device
=
"cuda"
mp
.
spawn
(
run_test_gpt2
,
args
=
(
world_size
,
backend
,
device
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
def
run_test_multiple_groups
(
rank
,
world_size
,
tempfile_name
,
backend
,
reduce_buffer_size
):
# Only work with the even ranks, to check that the global_rank indexing is properly used
dist
.
init_process_group
(
init_method
=
"file://"
+
tempfile_name
,
backend
=
backend
,
rank
=
rank
,
world_size
=
world_size
)
sub_group_ranks
=
[
0
,
2
]
process_group
=
torch
.
distributed
.
new_group
(
ranks
=
sub_group_ranks
,
backend
=
backend
)
# Make sure that all the ranks get different training data
# So that the sync check in between their models is meaningful
torch
.
manual_seed
(
rank
)
np
.
random
.
seed
(
rank
)
# Standard deep learning setup
device
=
"cuda"
torch
.
cuda
.
set_device
(
rank
)
epochs
,
batch
,
input_width
,
hidden
,
target_width
=
5
,
3
,
20
,
10
,
5
loss_fn
=
torch
.
nn
.
L1Loss
().
to
(
device
)
def
check
(
optimizer
,
model
):
# Just run a couple of epochs, check that the model is properly updated
for
_
in
range
(
epochs
):
target
=
torch
.
rand
((
batch
,
target_width
),
device
=
device
)
inputs
=
torch
.
rand
((
batch
,
input_width
),
device
=
device
)
def
closure
():
optimizer
.
zero_grad
()
output
=
model
(
inputs
)
loss
=
loss_fn
(
output
,
target
)
loss
.
backward
()
return
loss
_
=
optimizer
.
step
(
closure
=
closure
)
# Check that all the params are the same on all ranks
check_same_models_across_ranks
(
model
,
process_group
,
params_should_be_equal
=
True
,
check_broadcast_buffers
=
True
)
if
rank
in
sub_group_ranks
:
# Model not-fitting in the broadcast bucket
model
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Linear
(
input_width
,
hidden
),
torch
.
nn
.
Linear
(
hidden
,
target_width
)).
to
(
device
)
# With SGD, Momentum is required to get a state to shard
optimizer
=
OSS
(
model
.
parameters
(),
group
=
process_group
,
lr
=
1e-3
,
momentum
=
0.99
)
model
=
ShardedDataParallel
(
model
,
optimizer
,
process_group
=
process_group
,
reduce_buffer_size
=
reduce_buffer_size
)
check
(
optimizer
,
model
)
dist
.
destroy_process_group
(
process_group
)
@
skip_if_less_than_four_gpu
@
pytest
.
mark
.
parametrize
(
"reduce_buffer_size"
,
[
0
,
2
**
20
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"gloo"
,
"nccl"
])
def
test_multiple_groups
(
reduce_buffer_size
,
backend
):
world_size
=
4
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
run_test_multiple_groups
,
args
=
(
world_size
,
temp_file_name
,
backend
,
reduce_buffer_size
),
nprocs
=
world_size
,
join
=
True
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment