Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
3b7373e2
Unverified
Commit
3b7373e2
authored
Apr 29, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Apr 29, 2021
Browse files
[test][refactor][SDP] Using the nice context-based tempfiles (#640)
parent
8c8a625a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
146 additions
and
132 deletions
+146
-132
tests/nn/data_parallel/test_sharded_ddp_features.py
tests/nn/data_parallel/test_sharded_ddp_features.py
+113
-106
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
+33
-26
No files found.
tests/nn/data_parallel/test_sharded_ddp_features.py
View file @
3b7373e2
...
...
@@ -8,7 +8,6 @@ Testing ShardedDDP
"""
from
contextlib
import
suppress
import
tempfile
import
numpy
as
np
import
pytest
...
...
@@ -27,6 +26,7 @@ from fairscale.utils.testing import (
skip_if_less_than_four_gpu
,
skip_if_no_cuda
,
skip_if_single_gpu
,
temp_files_ctx
,
)
...
...
@@ -134,13 +134,13 @@ def run_one_step(
def
run_test
(
backend
,
device
,
world_size
,
broadcast_buffers
,
grad_accumulation
,
reduce_buffer_size
,
optimizer_type
):
temp_file
_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
run_one_step
,
args
=
(
world_size
,
backend
,
device
,
temp_file
_name
,
broadcast_buffers
,
grad_accumulation
,
reduce_buffer_size
),
nprocs
=
world_size
,
join
=
True
,
)
with
temp_file
s_ctx
(
num
=
1
)
as
temp
_
file
s
:
mp
.
spawn
(
run_one_step
,
args
=
(
world_size
,
backend
,
device
,
temp_file
s
[
0
]
,
broadcast_buffers
,
grad_accumulation
,
reduce_buffer_size
),
nprocs
=
world_size
,
join
=
True
,
)
@
skip_if_no_cuda
...
...
@@ -160,24 +160,23 @@ def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation,
)
def
test_step
(
broadcast_buffers
,
grad_accumulation
,
reduce_buffer_size
,
optimizer_type
,
reduce_fp16
,
setup
):
world_size
=
2
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
run_one_step
,
args
=
(
world_size
,
setup
[
0
],
setup
[
1
],
temp_file_name
,
broadcast_buffers
,
grad_accumulation
,
reduce_buffer_size
,
optimizer_type
,
reduce_fp16
,
),
nprocs
=
world_size
,
join
=
True
,
)
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
mp
.
spawn
(
run_one_step
,
args
=
(
world_size
,
setup
[
0
],
setup
[
1
],
temp_files
[
0
],
broadcast_buffers
,
grad_accumulation
,
reduce_buffer_size
,
optimizer_type
,
reduce_fp16
,
),
nprocs
=
world_size
,
join
=
True
,
)
def
run_test_two_inputs
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
,
reduce_buffer_size
):
...
...
@@ -200,7 +199,7 @@ def run_test_two_inputs(rank, world_size, backend, device, temp_file_name, reduc
loss
.
backward
()
return
loss
for
i
in
range
(
5
):
for
_
in
range
(
5
):
_
=
optimizer
.
step
(
closure
=
closure
)
dist
.
destroy_process_group
()
...
...
@@ -215,78 +214,82 @@ def test_inputs(reduce_buffer_size, backend, device):
if
backend
==
"nccl"
and
device
==
"cpu"
:
pytest
.
skip
(
"Incompatible combination, or cuda not available"
)
return
mp
.
spawn
(
run_test_two_inputs
,
args
=
(
world_size
,
backend
,
device
,
tempfile
.
mkstemp
()[
1
],
reduce_buffer_size
),
nprocs
=
world_size
,
join
=
True
,
)
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
mp
.
spawn
(
run_test_two_inputs
,
args
=
(
world_size
,
backend
,
device
,
temp
_
file
s
[
0
],
reduce_buffer_size
),
nprocs
=
world_size
,
join
=
True
,
)
def
test_ddp_attributes
():
# Check that ShardedDDP exposes the same attributes as Pytorch's DDP
# - is multi_device_module
# - device_type
dist
.
init_process_group
(
init_method
=
"file://"
+
tempfile
.
mkstemp
()[
1
],
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
dist
.
init_process_group
(
init_method
=
"file://"
+
temp_files
[
0
],
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
model
=
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
))
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
ddp_model
=
ShardedDataParallel
(
model
,
optimizer
)
model
=
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
))
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
ddp_model
=
ShardedDataParallel
(
model
,
optimizer
)
assert
hasattr
(
ddp_model
,
"is_multi_device_module"
)
assert
hasattr
(
ddp_model
,
"device_type"
)
dist
.
destroy_process_group
()
assert
hasattr
(
ddp_model
,
"is_multi_device_module"
)
assert
hasattr
(
ddp_model
,
"device_type"
)
dist
.
destroy_process_group
()
def
test_random_attributes
():
# Check that ShardedDDP exposes the original module's attributes
dist
.
init_process_group
(
init_method
=
"file://"
+
tempfile
.
mkstemp
()[
1
],
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
# Check that ShardedDDP exposes the original module's attributes
dist
.
init_process_group
(
init_method
=
"file://"
+
temp_files
[
0
],
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
model
=
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
))
model
.
banana
=
"sweet"
model
=
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
))
model
.
banana
=
"sweet"
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
ddp_model
=
ShardedDataParallel
(
model
,
optimizer
)
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
ddp_model
=
ShardedDataParallel
(
model
,
optimizer
)
assert
hasattr
(
ddp_model
,
"banana"
)
assert
not
hasattr
(
ddp_model
,
"orange"
)
assert
hasattr
(
ddp_model
,
"banana"
)
assert
not
hasattr
(
ddp_model
,
"orange"
)
dist
.
destroy_process_group
()
dist
.
destroy_process_group
()
def
test_catch_grad_grad
():
# Check that ShardedDDP exposes the original module's attributes
dist
.
init_process_group
(
init_method
=
"file://"
+
tempfile
.
mkstemp
()[
1
],
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
# Check that ShardedDDP exposes the original module's attributes
dist
.
init_process_group
(
init_method
=
"file://"
+
temp_files
[
0
],
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
model
=
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
))
model
.
train
()
chained_grad
=
torch
.
zeros_like
(
next
(
model
.
parameters
()))
chained_grad
.
requires_grad
=
True
next
(
model
.
parameters
()).
grad
=
chained_grad
model
=
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
))
model
.
train
()
chained_grad
=
torch
.
zeros_like
(
next
(
model
.
parameters
()))
chained_grad
.
requires_grad
=
True
next
(
model
.
parameters
()).
grad
=
chained_grad
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
ddp_model
=
ShardedDataParallel
(
model
,
optimizer
)
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
ddp_model
=
ShardedDataParallel
(
model
,
optimizer
)
inputs
=
torch
.
rand
(
100
,
2
)
with
pytest
.
raises
(
RuntimeError
):
_
=
ddp_model
(
inputs
)
inputs
=
torch
.
rand
(
100
,
2
)
with
pytest
.
raises
(
RuntimeError
):
_
=
ddp_model
(
inputs
)
dist
.
destroy_process_group
()
dist
.
destroy_process_group
()
def
test_mixed_types
():
# Check that ShardedDDP exposes the original module's attributes
dist
.
init_process_group
(
init_method
=
"file://"
+
tempfile
.
mkstemp
()[
1
],
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
# Check that ShardedDDP exposes the original module's attributes
dist
.
init_process_group
(
init_method
=
"file://"
+
temp_files
[
0
],
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
model
=
_get_mlp
(
tripwire
=
True
)
model
=
_get_mlp
(
tripwire
=
True
)
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
model
=
ShardedDataParallel
(
model
,
optimizer
)
input_tensor
=
torch
.
rand
((
2
,
2
))
_
=
model
(
input_tensor
)
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
model
=
ShardedDataParallel
(
model
,
optimizer
)
input_tensor
=
torch
.
rand
((
2
,
2
))
_
=
model
(
input_tensor
)
dist
.
destroy_process_group
()
dist
.
destroy_process_group
()
def
run_test_train_eval_change
(
rank
,
world_size
,
file
):
...
...
@@ -317,10 +320,10 @@ def run_test_train_eval_change(rank, world_size, file):
def
test_train_eval_change
():
world_size
=
4
temp_file
_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
run_test_train_eval_change
,
args
=
(
world_size
,
temp_file
_name
),
nprocs
=
world_size
,
join
=
True
,
)
with
temp_file
s_ctx
(
num
=
1
)
as
temp
_
file
s
:
mp
.
spawn
(
run_test_train_eval_change
,
args
=
(
world_size
,
temp_file
s
[
0
]
),
nprocs
=
world_size
,
join
=
True
,
)
def
run_test_device_change
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
,
reduce_buffer_size
):
...
...
@@ -352,14 +355,14 @@ def test_device_change(reduce_buffer_size):
# Check that ShardedDDP handles a device change properly
world_size
=
2
backend
=
"nccl"
temp_file
_name
=
tempfile
.
mkstemp
()[
1
]
device
=
"cuda"
mp
.
spawn
(
run_test_device_change
,
args
=
(
world_size
,
backend
,
device
,
temp_file
_name
,
reduce_buffer_size
),
nprocs
=
world_size
,
join
=
True
,
)
with
temp_file
s_ctx
(
num
=
1
)
as
temp
_
file
s
:
device
=
"cuda"
mp
.
spawn
(
run_test_device_change
,
args
=
(
world_size
,
backend
,
device
,
temp_file
s
[
0
]
,
reduce_buffer_size
),
nprocs
=
world_size
,
join
=
True
,
)
def
run_test_training_change
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
,
reduce_buffer_size
):
...
...
@@ -389,14 +392,14 @@ def run_test_training_change(rank, world_size, backend, device, temp_file_name,
def
test_training_change
(
reduce_buffer_size
):
world_size
=
2
backend
=
"nccl"
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
device
=
"cuda"
mp
.
spawn
(
run_test_training_change
,
args
=
(
world_size
,
backend
,
device
,
temp_file_name
,
reduce_buffer_size
),
nprocs
=
world_size
,
join
=
True
,
)
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
mp
.
spawn
(
run_test_training_change
,
args
=
(
world_size
,
backend
,
device
,
temp_files
[
0
],
reduce_buffer_size
),
nprocs
=
world_size
,
join
=
True
,
)
def
run_test_ddp_sync_batch_norm
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
):
...
...
@@ -421,11 +424,14 @@ def test_ddp_sync_batch_norm():
# Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
world_size
=
2
backend
=
"gloo"
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
device
=
"cuda"
mp
.
spawn
(
run_test_ddp_sync_batch_norm
,
args
=
(
world_size
,
backend
,
device
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
mp
.
spawn
(
run_test_ddp_sync_batch_norm
,
args
=
(
world_size
,
backend
,
device
,
temp_files
[
0
]),
nprocs
=
world_size
,
join
=
True
,
)
def
run_test_two_optimizers
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
):
...
...
@@ -463,9 +469,11 @@ def test_two_optimizers():
# Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
world_size
=
2
backend
=
"gloo"
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
device
=
"cpu"
mp
.
spawn
(
run_test_two_optimizers
,
args
=
(
world_size
,
backend
,
device
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
mp
.
spawn
(
run_test_two_optimizers
,
args
=
(
world_size
,
backend
,
device
,
temp_files
[
0
]),
nprocs
=
world_size
,
join
=
True
)
def
run_test_gpt2
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
):
...
...
@@ -510,9 +518,9 @@ def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
def
test_gpt2
(
world_size
):
# Check that having trainable unused params is fine
backend
=
"gloo"
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
device
=
"cuda"
mp
.
spawn
(
run_test_gpt2
,
args
=
(
world_size
,
backend
,
device
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
mp
.
spawn
(
run_test_gpt2
,
args
=
(
world_size
,
backend
,
device
,
temp_files
[
0
]),
nprocs
=
world_size
,
join
=
True
)
def
run_test_multiple_groups
(
rank
,
world_size
,
tempfile_name
,
backend
,
reduce_buffer_size
):
...
...
@@ -575,11 +583,10 @@ def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_bu
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"gloo"
,
"nccl"
])
def
test_multiple_groups
(
reduce_buffer_size
,
backend
):
world_size
=
4
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
run_test_multiple_groups
,
args
=
(
world_size
,
temp_file_name
,
backend
,
reduce_buffer_size
),
nprocs
=
world_size
,
join
=
True
,
)
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
mp
.
spawn
(
run_test_multiple_groups
,
args
=
(
world_size
,
temp_files
[
0
],
backend
,
reduce_buffer_size
),
nprocs
=
world_size
,
join
=
True
,
)
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
View file @
3b7373e2
...
...
@@ -9,7 +9,6 @@ Testing ShardedDDP
from
contextlib
import
suppress
import
copy
import
tempfile
import
numpy
as
np
import
pytest
...
...
@@ -23,7 +22,13 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from
fairscale.nn.data_parallel
import
ShardedDataParallel
from
fairscale.optim
import
OSS
from
fairscale.optim.grad_scaler
import
ShardedGradScaler
from
fairscale.utils.testing
import
check_same_model_params
,
skip_if_no_cuda
,
skip_if_single_gpu
,
torch_version
from
fairscale.utils.testing
import
(
check_same_model_params
,
skip_if_no_cuda
,
skip_if_single_gpu
,
temp_files_ctx
,
torch_version
,
)
"""
Check that ShardedDDP gets the same results as DDP in a variety of scenarii
...
...
@@ -250,24 +255,25 @@ def test_ddp_parity(
world_size
=
torch
.
cuda
.
device_count
()
backend
=
dist
.
Backend
.
NCCL
mp
.
spawn
(
run_ddp_parity
,
args
=
(
world_size
,
backend
,
tempfile
.
mkstemp
()[
1
],
reduce_buffer_size
,
grad_accumulation
,
change_train_graph
,
fp16_reduction
,
clip_grad_norm
,
amp
,
manual_reduction
,
multiple_fw
,
),
nprocs
=
world_size
,
join
=
True
,
)
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
mp
.
spawn
(
run_ddp_parity
,
args
=
(
world_size
,
backend
,
temp_files
[
0
],
reduce_buffer_size
,
grad_accumulation
,
change_train_graph
,
fp16_reduction
,
clip_grad_norm
,
amp
,
manual_reduction
,
multiple_fw
,
),
nprocs
=
world_size
,
join
=
True
,
)
def
run_ddp_parity_two_optim
(
rank
,
world_size
,
backend
,
temp_file_name
,
reduce_buffer_size
):
...
...
@@ -340,9 +346,10 @@ def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_b
def
test_ddp_parity_two_optim
(
reduce_buffer_size
):
world_size
=
2
backend
=
dist
.
Backend
.
NCCL
mp
.
spawn
(
run_ddp_parity_two_optim
,
args
=
(
world_size
,
backend
,
tempfile
.
mkstemp
()[
1
],
reduce_buffer_size
),
nprocs
=
world_size
,
join
=
True
,
)
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
mp
.
spawn
(
run_ddp_parity_two_optim
,
args
=
(
world_size
,
backend
,
temp_files
[
0
],
reduce_buffer_size
),
nprocs
=
world_size
,
join
=
True
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment