OpenDAS / fairscale / Commits

Commit 3b7373e2 (unverified)
Authored Apr 29, 2021 by Benjamin Lefaudeux, committed via GitHub on Apr 29, 2021
Parent: 8c8a625a

[test][refactor][SDP] Using the nice context-based tempfiles (#640)

Showing 2 changed files with 146 additions and 132 deletions:
- tests/nn/data_parallel/test_sharded_ddp_features.py  (+113, -106)
- tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py  (+33, -26)
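The change replaces ad-hoc rendezvous files obtained from tempfile.mkstemp()[1] (which are never deleted) with the temp_files_ctx helper imported from fairscale.utils.testing, which hands the test a list of temporary file names and removes them when the with-block exits. The helper's implementation is not part of this diff; the following is only a minimal sketch of what such a context manager could look like (names and cleanup behaviour are assumptions, not code from fairscale):

import contextlib
import os
import tempfile
from typing import Iterator, List


@contextlib.contextmanager
def temp_files_ctx(num: int) -> Iterator[List[str]]:
    # Hypothetical stand-in for fairscale.utils.testing.temp_files_ctx:
    # create `num` temp files, yield their paths, then close and remove them.
    handles = [tempfile.mkstemp() for _ in range(num)]
    try:
        yield [path for _, path in handles]
    finally:
        for fd, path in handles:
            with contextlib.suppress(OSError):
                os.close(fd)
            with contextlib.suppress(OSError):
                os.remove(path)

In every hunk below, the test body moves inside "with temp_files_ctx(num=1) as temp_files:" and passes temp_files[0] where tempfile.mkstemp()[1] used to be passed.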
tests/nn/data_parallel/test_sharded_ddp_features.py  (view file @ 3b7373e2)
@@ -8,7 +8,6 @@ Testing ShardedDDP
 """

 from contextlib import suppress
-import tempfile

 import numpy as np
 import pytest
@@ -27,6 +26,7 @@ from fairscale.utils.testing import (
     skip_if_less_than_four_gpu,
     skip_if_no_cuda,
     skip_if_single_gpu,
+    temp_files_ctx,
 )
@@ -134,10 +134,10 @@ def run_one_step(
 def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type):
-    temp_file_name = tempfile.mkstemp()[1]
-
+    with temp_files_ctx(num=1) as temp_files:
         mp.spawn(
             run_one_step,
-            args=(world_size, backend, device, temp_file_name, broadcast_buffers, grad_accumulation, reduce_buffer_size),
+            args=(world_size, backend, device, temp_files[0], broadcast_buffers, grad_accumulation, reduce_buffer_size),
             nprocs=world_size,
             join=True,
         )
@@ -160,15 +160,14 @@ def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation,
 )
 def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type, reduce_fp16, setup):
     world_size = 2
-    temp_file_name = tempfile.mkstemp()[1]
-
+    with temp_files_ctx(num=1) as temp_files:
         mp.spawn(
             run_one_step,
             args=(
                 world_size,
                 setup[0],
                 setup[1],
-                temp_file_name,
+                temp_files[0],
                 broadcast_buffers,
                 grad_accumulation,
                 reduce_buffer_size,
@@ -200,7 +199,7 @@ def run_test_two_inputs(rank, world_size, backend, device, temp_file_name, reduc
         loss.backward()
         return loss

-    for i in range(5):
+    for _ in range(5):
         _ = optimizer.step(closure=closure)

     dist.destroy_process_group()
@@ -215,10 +214,10 @@ def test_inputs(reduce_buffer_size, backend, device):
     if backend == "nccl" and device == "cpu":
         pytest.skip("Incompatible combination, or cuda not available")
         return
+    with temp_files_ctx(num=1) as temp_files:
         mp.spawn(
             run_test_two_inputs,
-            args=(world_size, backend, device, tempfile.mkstemp()[1], reduce_buffer_size),
+            args=(world_size, backend, device, temp_files[0], reduce_buffer_size),
             nprocs=world_size,
             join=True,
         )
@@ -228,7 +227,8 @@ def test_ddp_attributes():
     # Check that ShardedDDP exposes the same attributes as Pytorch's DDP
     # - is multi_device_module
     # - device_type
-    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+    with temp_files_ctx(num=1) as temp_files:
+        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)

         model = Sequential(Linear(2, 3), Linear(3, 3))
         optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
@@ -240,8 +240,9 @@ def test_ddp_attributes():
 def test_random_attributes():
+    with temp_files_ctx(num=1) as temp_files:
         # Check that ShardedDDP exposes the original module's attributes
-    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)

         model = Sequential(Linear(2, 3), Linear(3, 3))
         model.banana = "sweet"
@@ -256,8 +257,9 @@ def test_random_attributes():
 def test_catch_grad_grad():
+    with temp_files_ctx(num=1) as temp_files:
         # Check that ShardedDDP exposes the original module's attributes
-    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)

         model = Sequential(Linear(2, 3), Linear(3, 3))
         model.train()
@@ -276,8 +278,9 @@ def test_catch_grad_grad():
 def test_mixed_types():
+    with temp_files_ctx(num=1) as temp_files:
         # Check that ShardedDDP exposes the original module's attributes
-    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)

         model = _get_mlp(tripwire=True)
@@ -317,9 +320,9 @@ def run_test_train_eval_change(rank, world_size, file):
 def test_train_eval_change():
     world_size = 4
-    temp_file_name = tempfile.mkstemp()[1]
-
-    mp.spawn(
-        run_test_train_eval_change,
-        args=(world_size, temp_file_name),
-        nprocs=world_size,
-        join=True,
-    )
+    with temp_files_ctx(num=1) as temp_files:
+        mp.spawn(
+            run_test_train_eval_change,
+            args=(world_size, temp_files[0]),
+            nprocs=world_size,
+            join=True,
+        )
@@ -352,11 +355,11 @@ def test_device_change(reduce_buffer_size):
     # Check that ShardedDDP handles a device change properly
     world_size = 2
     backend = "nccl"
-    temp_file_name = tempfile.mkstemp()[1]
+    with temp_files_ctx(num=1) as temp_files:
         device = "cuda"
         mp.spawn(
             run_test_device_change,
-            args=(world_size, backend, device, temp_file_name, reduce_buffer_size),
+            args=(world_size, backend, device, temp_files[0], reduce_buffer_size),
             nprocs=world_size,
             join=True,
         )
@@ -389,11 +392,11 @@ def run_test_training_change(rank, world_size, backend, device, temp_file_name,
 def test_training_change(reduce_buffer_size):
     world_size = 2
     backend = "nccl"
-    temp_file_name = tempfile.mkstemp()[1]
     device = "cuda"
+    with temp_files_ctx(num=1) as temp_files:
         mp.spawn(
             run_test_training_change,
-            args=(world_size, backend, device, temp_file_name, reduce_buffer_size),
+            args=(world_size, backend, device, temp_files[0], reduce_buffer_size),
             nprocs=world_size,
             join=True,
         )
@@ -421,10 +424,13 @@ def test_ddp_sync_batch_norm():
     # Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
     world_size = 2
     backend = "gloo"
-    temp_file_name = tempfile.mkstemp()[1]
     device = "cuda"
-    mp.spawn(
-        run_test_ddp_sync_batch_norm, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True
-    )
+    with temp_files_ctx(num=1) as temp_files:
+        mp.spawn(
+            run_test_ddp_sync_batch_norm,
+            args=(world_size, backend, device, temp_files[0]),
+            nprocs=world_size,
+            join=True,
+        )
@@ -463,9 +469,11 @@ def test_two_optimizers():
     # Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
     world_size = 2
     backend = "gloo"
-    temp_file_name = tempfile.mkstemp()[1]
     device = "cpu"
-    mp.spawn(run_test_two_optimizers, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
+    with temp_files_ctx(num=1) as temp_files:
+        mp.spawn(
+            run_test_two_optimizers, args=(world_size, backend, device, temp_files[0]), nprocs=world_size, join=True
+        )


 def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
@@ -510,9 +518,9 @@ def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
 def test_gpt2(world_size):
     # Check that having trainable unused params is fine
     backend = "gloo"
-    temp_file_name = tempfile.mkstemp()[1]
     device = "cuda"
-    mp.spawn(run_test_gpt2, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
+    with temp_files_ctx(num=1) as temp_files:
+        mp.spawn(run_test_gpt2, args=(world_size, backend, device, temp_files[0]), nprocs=world_size, join=True)


 def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_buffer_size):
@@ -575,11 +583,10 @@ def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_bu
 @pytest.mark.parametrize("backend", ["gloo", "nccl"])
 def test_multiple_groups(reduce_buffer_size, backend):
     world_size = 4
-    temp_file_name = tempfile.mkstemp()[1]
-
+    with temp_files_ctx(num=1) as temp_files:
         mp.spawn(
             run_test_multiple_groups,
-            args=(world_size, temp_file_name, backend, reduce_buffer_size),
+            args=(world_size, temp_files[0], backend, reduce_buffer_size),
             nprocs=world_size,
             join=True,
         )
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py  (view file @ 3b7373e2)
@@ -9,7 +9,6 @@ Testing ShardedDDP
 from contextlib import suppress
 import copy
-import tempfile

 import numpy as np
 import pytest
@@ -23,7 +22,13 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from fairscale.nn.data_parallel import ShardedDataParallel
 from fairscale.optim import OSS
 from fairscale.optim.grad_scaler import ShardedGradScaler
-from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_single_gpu, torch_version
+from fairscale.utils.testing import (
+    check_same_model_params,
+    skip_if_no_cuda,
+    skip_if_single_gpu,
+    temp_files_ctx,
+    torch_version,
+)

 """
 Check that ShardedDDP gets the same results as DDP in a variety of scenarii
@@ -250,12 +255,13 @@ def test_ddp_parity(
     world_size = torch.cuda.device_count()
     backend = dist.Backend.NCCL
+    with temp_files_ctx(num=1) as temp_files:
         mp.spawn(
             run_ddp_parity,
             args=(
                 world_size,
                 backend,
-                tempfile.mkstemp()[1],
+                temp_files[0],
                 reduce_buffer_size,
                 grad_accumulation,
                 change_train_graph,
@@ -340,9 +346,10 @@ def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_b
 def test_ddp_parity_two_optim(reduce_buffer_size):
     world_size = 2
     backend = dist.Backend.NCCL
+    with temp_files_ctx(num=1) as temp_files:
         mp.spawn(
             run_ddp_parity_two_optim,
-            args=(world_size, backend, tempfile.mkstemp()[1], reduce_buffer_size),
+            args=(world_size, backend, temp_files[0], reduce_buffer_size),
             nprocs=world_size,
             join=True,
         )
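For reference, the new pattern in isolation: spawn the workers, hand them a context-managed rendezvous file, and let the context remove it afterwards. This is an illustrative sketch, not code from the diff; it assumes PyTorch and fairscale are installed and reuses the temp_files_ctx import that this commit adds to the test files.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fairscale.utils.testing import temp_files_ctx


def _worker(rank, world_size, sync_file):
    # Rendezvous through the shared temp file, run one tiny collective, tear down.
    dist.init_process_group(backend="gloo", init_method="file://" + sync_file, rank=rank, world_size=world_size)
    t = torch.ones(1)
    dist.all_reduce(t)
    assert t.item() == world_size
    dist.destroy_process_group()


def test_spawn_with_context_tempfile():
    world_size = 2
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(_worker, args=(world_size, temp_files[0]), nprocs=world_size, join=True)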