Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
3b7373e2
Unverified
Commit
3b7373e2
authored
Apr 29, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Apr 29, 2021
Browse files
[test][refactor][SDP] Using the nice context-based tempfiles (#640)
parent
8c8a625a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
146 additions
and
132 deletions
+146
-132
tests/nn/data_parallel/test_sharded_ddp_features.py
tests/nn/data_parallel/test_sharded_ddp_features.py
+113
-106
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
+33
-26
No files found.
tests/nn/data_parallel/test_sharded_ddp_features.py
View file @
3b7373e2
...
@@ -8,7 +8,6 @@ Testing ShardedDDP
...
@@ -8,7 +8,6 @@ Testing ShardedDDP
"""
"""
from
contextlib
import
suppress
from
contextlib
import
suppress
import
tempfile
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
...
@@ -27,6 +26,7 @@ from fairscale.utils.testing import (
...
@@ -27,6 +26,7 @@ from fairscale.utils.testing import (
skip_if_less_than_four_gpu
,
skip_if_less_than_four_gpu
,
skip_if_no_cuda
,
skip_if_no_cuda
,
skip_if_single_gpu
,
skip_if_single_gpu
,
temp_files_ctx
,
)
)
...
@@ -134,10 +134,10 @@ def run_one_step(
...
@@ -134,10 +134,10 @@ def run_one_step(
def
run_test
(
backend
,
device
,
world_size
,
broadcast_buffers
,
grad_accumulation
,
reduce_buffer_size
,
optimizer_type
):
def
run_test
(
backend
,
device
,
world_size
,
broadcast_buffers
,
grad_accumulation
,
reduce_buffer_size
,
optimizer_type
):
temp_file
_name
=
tempfile
.
mkstemp
()[
1
]
with
temp_file
s_ctx
(
num
=
1
)
as
temp
_
file
s
:
mp
.
spawn
(
mp
.
spawn
(
run_one_step
,
run_one_step
,
args
=
(
world_size
,
backend
,
device
,
temp_file
_name
,
broadcast_buffers
,
grad_accumulation
,
reduce_buffer_size
),
args
=
(
world_size
,
backend
,
device
,
temp_file
s
[
0
]
,
broadcast_buffers
,
grad_accumulation
,
reduce_buffer_size
),
nprocs
=
world_size
,
nprocs
=
world_size
,
join
=
True
,
join
=
True
,
)
)
...
@@ -160,15 +160,14 @@ def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation,
...
@@ -160,15 +160,14 @@ def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation,
)
)
def
test_step
(
broadcast_buffers
,
grad_accumulation
,
reduce_buffer_size
,
optimizer_type
,
reduce_fp16
,
setup
):
def
test_step
(
broadcast_buffers
,
grad_accumulation
,
reduce_buffer_size
,
optimizer_type
,
reduce_fp16
,
setup
):
world_size
=
2
world_size
=
2
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
mp
.
spawn
(
mp
.
spawn
(
run_one_step
,
run_one_step
,
args
=
(
args
=
(
world_size
,
world_size
,
setup
[
0
],
setup
[
0
],
setup
[
1
],
setup
[
1
],
temp_file
_name
,
temp_file
s
[
0
]
,
broadcast_buffers
,
broadcast_buffers
,
grad_accumulation
,
grad_accumulation
,
reduce_buffer_size
,
reduce_buffer_size
,
...
@@ -200,7 +199,7 @@ def run_test_two_inputs(rank, world_size, backend, device, temp_file_name, reduc
...
@@ -200,7 +199,7 @@ def run_test_two_inputs(rank, world_size, backend, device, temp_file_name, reduc
loss
.
backward
()
loss
.
backward
()
return
loss
return
loss
for
i
in
range
(
5
):
for
_
in
range
(
5
):
_
=
optimizer
.
step
(
closure
=
closure
)
_
=
optimizer
.
step
(
closure
=
closure
)
dist
.
destroy_process_group
()
dist
.
destroy_process_group
()
...
@@ -215,10 +214,10 @@ def test_inputs(reduce_buffer_size, backend, device):
...
@@ -215,10 +214,10 @@ def test_inputs(reduce_buffer_size, backend, device):
if
backend
==
"nccl"
and
device
==
"cpu"
:
if
backend
==
"nccl"
and
device
==
"cpu"
:
pytest
.
skip
(
"Incompatible combination, or cuda not available"
)
pytest
.
skip
(
"Incompatible combination, or cuda not available"
)
return
return
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
mp
.
spawn
(
mp
.
spawn
(
run_test_two_inputs
,
run_test_two_inputs
,
args
=
(
world_size
,
backend
,
device
,
tempfile
.
mkstemp
()[
1
],
reduce_buffer_size
),
args
=
(
world_size
,
backend
,
device
,
temp
_
file
s
[
0
],
reduce_buffer_size
),
nprocs
=
world_size
,
nprocs
=
world_size
,
join
=
True
,
join
=
True
,
)
)
...
@@ -228,7 +227,8 @@ def test_ddp_attributes():
...
@@ -228,7 +227,8 @@ def test_ddp_attributes():
# Check that ShardedDDP exposes the same attributes as Pytorch's DDP
# Check that ShardedDDP exposes the same attributes as Pytorch's DDP
# - is multi_device_module
# - is multi_device_module
# - device_type
# - device_type
dist
.
init_process_group
(
init_method
=
"file://"
+
tempfile
.
mkstemp
()[
1
],
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
dist
.
init_process_group
(
init_method
=
"file://"
+
temp_files
[
0
],
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
model
=
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
))
model
=
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
))
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
...
@@ -240,8 +240,9 @@ def test_ddp_attributes():
...
@@ -240,8 +240,9 @@ def test_ddp_attributes():
def
test_random_attributes
():
def
test_random_attributes
():
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
# Check that ShardedDDP exposes the original module's attributes
# Check that ShardedDDP exposes the original module's attributes
dist
.
init_process_group
(
init_method
=
"file://"
+
tempfile
.
mkstemp
()[
1
],
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
dist
.
init_process_group
(
init_method
=
"file://"
+
temp
_
file
s
[
0
],
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
model
=
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
))
model
=
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
))
model
.
banana
=
"sweet"
model
.
banana
=
"sweet"
...
@@ -256,8 +257,9 @@ def test_random_attributes():
...
@@ -256,8 +257,9 @@ def test_random_attributes():
def
test_catch_grad_grad
():
def
test_catch_grad_grad
():
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
# Check that ShardedDDP exposes the original module's attributes
# Check that ShardedDDP exposes the original module's attributes
dist
.
init_process_group
(
init_method
=
"file://"
+
tempfile
.
mkstemp
()[
1
],
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
dist
.
init_process_group
(
init_method
=
"file://"
+
temp
_
file
s
[
0
],
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
model
=
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
))
model
=
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
))
model
.
train
()
model
.
train
()
...
@@ -276,8 +278,9 @@ def test_catch_grad_grad():
...
@@ -276,8 +278,9 @@ def test_catch_grad_grad():
def
test_mixed_types
():
def
test_mixed_types
():
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
# Check that ShardedDDP exposes the original module's attributes
# Check that ShardedDDP exposes the original module's attributes
dist
.
init_process_group
(
init_method
=
"file://"
+
tempfile
.
mkstemp
()[
1
],
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
dist
.
init_process_group
(
init_method
=
"file://"
+
temp
_
file
s
[
0
],
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
model
=
_get_mlp
(
tripwire
=
True
)
model
=
_get_mlp
(
tripwire
=
True
)
...
@@ -317,9 +320,9 @@ def run_test_train_eval_change(rank, world_size, file):
...
@@ -317,9 +320,9 @@ def run_test_train_eval_change(rank, world_size, file):
def
test_train_eval_change
():
def
test_train_eval_change
():
world_size
=
4
world_size
=
4
temp_file
_name
=
tempfile
.
mkstemp
()[
1
]
with
temp_file
s_ctx
(
num
=
1
)
as
temp
_
file
s
:
mp
.
spawn
(
mp
.
spawn
(
run_test_train_eval_change
,
args
=
(
world_size
,
temp_file
_name
),
nprocs
=
world_size
,
join
=
True
,
run_test_train_eval_change
,
args
=
(
world_size
,
temp_file
s
[
0
]
),
nprocs
=
world_size
,
join
=
True
,
)
)
...
@@ -352,11 +355,11 @@ def test_device_change(reduce_buffer_size):
...
@@ -352,11 +355,11 @@ def test_device_change(reduce_buffer_size):
# Check that ShardedDDP handles a device change properly
# Check that ShardedDDP handles a device change properly
world_size
=
2
world_size
=
2
backend
=
"nccl"
backend
=
"nccl"
temp_file
_name
=
tempfile
.
mkstemp
()[
1
]
with
temp_file
s_ctx
(
num
=
1
)
as
temp
_
file
s
:
device
=
"cuda"
device
=
"cuda"
mp
.
spawn
(
mp
.
spawn
(
run_test_device_change
,
run_test_device_change
,
args
=
(
world_size
,
backend
,
device
,
temp_file
_name
,
reduce_buffer_size
),
args
=
(
world_size
,
backend
,
device
,
temp_file
s
[
0
]
,
reduce_buffer_size
),
nprocs
=
world_size
,
nprocs
=
world_size
,
join
=
True
,
join
=
True
,
)
)
...
@@ -389,11 +392,11 @@ def run_test_training_change(rank, world_size, backend, device, temp_file_name,
...
@@ -389,11 +392,11 @@ def run_test_training_change(rank, world_size, backend, device, temp_file_name,
def
test_training_change
(
reduce_buffer_size
):
def
test_training_change
(
reduce_buffer_size
):
world_size
=
2
world_size
=
2
backend
=
"nccl"
backend
=
"nccl"
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
device
=
"cuda"
device
=
"cuda"
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
mp
.
spawn
(
mp
.
spawn
(
run_test_training_change
,
run_test_training_change
,
args
=
(
world_size
,
backend
,
device
,
temp_file
_name
,
reduce_buffer_size
),
args
=
(
world_size
,
backend
,
device
,
temp_file
s
[
0
]
,
reduce_buffer_size
),
nprocs
=
world_size
,
nprocs
=
world_size
,
join
=
True
,
join
=
True
,
)
)
...
@@ -421,10 +424,13 @@ def test_ddp_sync_batch_norm():
...
@@ -421,10 +424,13 @@ def test_ddp_sync_batch_norm():
# Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
# Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
world_size
=
2
world_size
=
2
backend
=
"gloo"
backend
=
"gloo"
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
device
=
"cuda"
device
=
"cuda"
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
mp
.
spawn
(
mp
.
spawn
(
run_test_ddp_sync_batch_norm
,
args
=
(
world_size
,
backend
,
device
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
run_test_ddp_sync_batch_norm
,
args
=
(
world_size
,
backend
,
device
,
temp_files
[
0
]),
nprocs
=
world_size
,
join
=
True
,
)
)
...
@@ -463,9 +469,11 @@ def test_two_optimizers():
...
@@ -463,9 +469,11 @@ def test_two_optimizers():
# Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
# Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
world_size
=
2
world_size
=
2
backend
=
"gloo"
backend
=
"gloo"
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
device
=
"cpu"
device
=
"cpu"
mp
.
spawn
(
run_test_two_optimizers
,
args
=
(
world_size
,
backend
,
device
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
mp
.
spawn
(
run_test_two_optimizers
,
args
=
(
world_size
,
backend
,
device
,
temp_files
[
0
]),
nprocs
=
world_size
,
join
=
True
)
def
run_test_gpt2
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
):
def
run_test_gpt2
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
):
...
@@ -510,9 +518,9 @@ def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
...
@@ -510,9 +518,9 @@ def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
def
test_gpt2
(
world_size
):
def
test_gpt2
(
world_size
):
# Check that having trainable unused params is fine
# Check that having trainable unused params is fine
backend
=
"gloo"
backend
=
"gloo"
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
device
=
"cuda"
device
=
"cuda"
mp
.
spawn
(
run_test_gpt2
,
args
=
(
world_size
,
backend
,
device
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
mp
.
spawn
(
run_test_gpt2
,
args
=
(
world_size
,
backend
,
device
,
temp_files
[
0
]),
nprocs
=
world_size
,
join
=
True
)
def
run_test_multiple_groups
(
rank
,
world_size
,
tempfile_name
,
backend
,
reduce_buffer_size
):
def
run_test_multiple_groups
(
rank
,
world_size
,
tempfile_name
,
backend
,
reduce_buffer_size
):
...
@@ -575,11 +583,10 @@ def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_bu
...
@@ -575,11 +583,10 @@ def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_bu
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"gloo"
,
"nccl"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"gloo"
,
"nccl"
])
def
test_multiple_groups
(
reduce_buffer_size
,
backend
):
def
test_multiple_groups
(
reduce_buffer_size
,
backend
):
world_size
=
4
world_size
=
4
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
mp
.
spawn
(
mp
.
spawn
(
run_test_multiple_groups
,
run_test_multiple_groups
,
args
=
(
world_size
,
temp_file
_name
,
backend
,
reduce_buffer_size
),
args
=
(
world_size
,
temp_file
s
[
0
]
,
backend
,
reduce_buffer_size
),
nprocs
=
world_size
,
nprocs
=
world_size
,
join
=
True
,
join
=
True
,
)
)
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
View file @
3b7373e2
...
@@ -9,7 +9,6 @@ Testing ShardedDDP
...
@@ -9,7 +9,6 @@ Testing ShardedDDP
from
contextlib
import
suppress
from
contextlib
import
suppress
import
copy
import
copy
import
tempfile
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
...
@@ -23,7 +22,13 @@ from torch.nn.parallel import DistributedDataParallel as DDP
...
@@ -23,7 +22,13 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from
fairscale.nn.data_parallel
import
ShardedDataParallel
from
fairscale.nn.data_parallel
import
ShardedDataParallel
from
fairscale.optim
import
OSS
from
fairscale.optim
import
OSS
from
fairscale.optim.grad_scaler
import
ShardedGradScaler
from
fairscale.optim.grad_scaler
import
ShardedGradScaler
from
fairscale.utils.testing
import
check_same_model_params
,
skip_if_no_cuda
,
skip_if_single_gpu
,
torch_version
from
fairscale.utils.testing
import
(
check_same_model_params
,
skip_if_no_cuda
,
skip_if_single_gpu
,
temp_files_ctx
,
torch_version
,
)
"""
"""
Check that ShardedDDP gets the same results as DDP in a variety of scenarii
Check that ShardedDDP gets the same results as DDP in a variety of scenarii
...
@@ -250,12 +255,13 @@ def test_ddp_parity(
...
@@ -250,12 +255,13 @@ def test_ddp_parity(
world_size
=
torch
.
cuda
.
device_count
()
world_size
=
torch
.
cuda
.
device_count
()
backend
=
dist
.
Backend
.
NCCL
backend
=
dist
.
Backend
.
NCCL
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
mp
.
spawn
(
mp
.
spawn
(
run_ddp_parity
,
run_ddp_parity
,
args
=
(
args
=
(
world_size
,
world_size
,
backend
,
backend
,
tempfile
.
mkstemp
()[
1
],
temp
_
file
s
[
0
],
reduce_buffer_size
,
reduce_buffer_size
,
grad_accumulation
,
grad_accumulation
,
change_train_graph
,
change_train_graph
,
...
@@ -340,9 +346,10 @@ def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_b
...
@@ -340,9 +346,10 @@ def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_b
def
test_ddp_parity_two_optim
(
reduce_buffer_size
):
def
test_ddp_parity_two_optim
(
reduce_buffer_size
):
world_size
=
2
world_size
=
2
backend
=
dist
.
Backend
.
NCCL
backend
=
dist
.
Backend
.
NCCL
with
temp_files_ctx
(
num
=
1
)
as
temp_files
:
mp
.
spawn
(
mp
.
spawn
(
run_ddp_parity_two_optim
,
run_ddp_parity_two_optim
,
args
=
(
world_size
,
backend
,
tempfile
.
mkstemp
()[
1
],
reduce_buffer_size
),
args
=
(
world_size
,
backend
,
temp
_
file
s
[
0
],
reduce_buffer_size
),
nprocs
=
world_size
,
nprocs
=
world_size
,
join
=
True
,
join
=
True
,
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment