Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
df493a29
Unverified
Commit
df493a29
authored
Mar 22, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Mar 22, 2021
Browse files
[ci][SDP] extending the test matrix which checks for equivalence with DDP (#542)
parent
fa1b85fb
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
57 additions
and
36 deletions
+57
-36
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
+57
-36
No files found.
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
View file @
df493a29
...
...
@@ -34,13 +34,26 @@ _test_fp16_reduction = [False]
if
hasattr
(
dist
,
"algorithms.ddp_com_hooks.default_hooks"
):
_test_fp16_reduction
.
append
(
True
)
_test_amp
=
[
False
]
if
hasattr
(
torch
.
cuda
.
amp
,
"autocast"
):
_test_amp
.
append
(
True
)
def
_get_mlp
():
return
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
))
def
run_ddp_parity
(
rank
,
world_size
,
backend
,
temp_file_name
,
reduce_buffer_size
,
grad_accumulation
,
change_train_graph
,
fp16_reduction
rank
,
world_size
,
backend
,
temp_file_name
,
reduce_buffer_size
,
grad_accumulation
,
change_train_graph
,
fp16_reduction
,
clip_grad_norm
,
amp
,
):
dist
.
init_process_group
(
init_method
=
"file://"
+
temp_file_name
,
backend
=
backend
,
rank
=
rank
,
world_size
=
world_size
)
...
...
@@ -51,7 +64,7 @@ def run_ddp_parity(
NUMBER_BATCHS
=
5
BATCH_SIZE
=
8
def
check_parity
(
amp
:
bool
,
manual_reduction
:
bool
):
def
check_parity
(
manual_reduction
:
bool
):
# The API should be the exact same in between the sharded and non-sharded variants, generic closure
def
closure
(
model
,
scaler
,
input_tensor
,
should_accumulate
,
_manual_reduction
=
False
):
...
...
@@ -108,7 +121,7 @@ def run_ddp_parity(
ddp_model
.
register_comm_hook
(
state
=
None
,
hook
=
fp16_compress_hook
)
# type: ignore
ddp_scaler
=
TorchGradScaler
()
if
amp
else
None
sharded_
ddp_
scaler
=
ShardedGradScaler
()
if
amp
else
None
sharded_scaler
=
ShardedGradScaler
()
if
amp
else
None
# The model should be synchronized in between the ranks at construction time, check that
check_same_model_params
(
sharded_ddp_model
,
ddp_model
)
...
...
@@ -117,35 +130,44 @@ def run_ddp_parity(
for
i
in
range
(
NUMBER_BATCHS
):
input_tensor
=
torch
.
rand
((
BATCH_SIZE
,
2
)).
to
(
device
)
def
closure
_ddp
(
input_tensor
=
input_tensor
):
def
ddp_
closure
(
input_tensor
=
input_tensor
):
return
closure
(
ddp_model
,
ddp_scaler
,
input_tensor
,
grad_accumulation
)
def
closure_sharded
(
input_tensor
=
input_tensor
):
def
sharded_closure
(
input_tensor
=
input_tensor
):
return
closure
(
sharded_ddp_model
,
sharded_
ddp_
scaler
,
sharded_scaler
,
input_tensor
,
grad_accumulation
,
_manual_reduction
=
manual_reduction
,
)
# Step/scale both
if
ddp_scaler
is
not
None
:
_
=
closure_ddp
(
input_tensor
)
ddp_scaler
.
step
(
ddp_optimizer
)
ddp_scaler
.
update
()
else
:
ddp_optimizer
.
step
(
closure
=
closure_ddp
)
if
sharded_ddp_scaler
is
not
None
:
_
=
closure_sharded
(
input_tensor
)
sharded_ddp_scaler
.
step
(
sharded_optimizer
)
sharded_ddp_scaler
.
update
()
else
:
sharded_optimizer
.
step
(
closure
=
closure_sharded
)
for
_scaler
,
_closure
,
_optimizer
in
(
(
ddp_scaler
,
ddp_closure
,
ddp_optimizer
),
(
sharded_scaler
,
sharded_closure
,
sharded_optimizer
),
):
if
_scaler
is
not
None
:
_
=
_closure
(
input_tensor
)
_scaler
.
step
(
_optimizer
)
_scaler
.
update
()
check_same_model_params
(
sharded_ddp_model
,
ddp_model
,
f
"Rank:
{
rank
}
- Step
{
i
}
broke"
)
# Check that the two grad norm are equivalent
# NOTE: The grads can occasionally be NaNs, the scaler will skip the step in that case
# This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
# be valid for ShardedDDP
if
clip_grad_norm
:
total_norm
=
torch
.
nn
.
utils
.
clip_grad_norm_
(
ddp_model
.
parameters
(),
0.3
,
norm_type
=
2.0
)
# type: ignore
if
not
torch
.
isnan
(
total_norm
):
oss_total_norm
=
sharded_optimizer
.
clip_grad_norm
(
0.3
,
norm_type
=
2.0
)
assert
torch
.
allclose
(
oss_total_norm
,
total_norm
,
atol
=
1e-2
if
amp
else
1e-8
),
f
"torch and fairscale should return the same grad norm
\n
{
oss_total_norm
}
vs
{
total_norm
}
"
else
:
print
(
rank
,
"NaN grad norm in DDP"
,
flush
=
True
)
# Flip the trainability of the first parameter back and forth
if
i
==
0
and
change_train_graph
:
next
(
sharded_ddp_model
.
parameters
()).
requires_grad
=
not
next
(
...
...
@@ -155,24 +177,19 @@ def run_ddp_parity(
check_same_model_params
(
sharded_ddp_model
,
ddp_model
,
f
"Rank:
{
rank
}
- Trainability refresh
{
i
}
broke"
)
# Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
amp_tests
=
[
False
]
if
hasattr
(
torch
.
cuda
.
amp
,
"autocast"
):
amp_tests
.
append
(
True
)
manual_reductions
=
[
False
,
True
]
if
not
grad_accumulation
and
not
change_train_graph
else
[
False
]
for
manual_reduction
in
manual_reductions
:
for
amp
in
amp_tests
:
print
(
f
"
Checking configuration: accumulate
{
grad_accumulation
}
"
f
"
{
rank
}
:
Checking configuration: accumulate
{
grad_accumulation
}
"
+
f
" - change train graph
{
change_train_graph
}
"
+
f
" - amp
{
amp
}
"
+
f
" - manual reduction
{
manual_reduction
}
"
+
f
" - buffers
{
reduce_buffer_size
}
"
,
flush
=
True
,
)
check_parity
(
amp
=
amp
,
manual_reduction
=
manual_reduction
,
)
check_parity
(
manual_reduction
=
manual_reduction
)
torch
.
cuda
.
synchronize
()
torch
.
distributed
.
barrier
(
)
dist
.
destroy_process_group
()
...
...
@@ -183,7 +200,9 @@ def run_ddp_parity(
@
pytest
.
mark
.
parametrize
(
"grad_accumulation"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"change_train_graph"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"fp16_reduction"
,
_test_fp16_reduction
)
def
test_ddp_parity
(
reduce_buffer_size
,
grad_accumulation
,
change_train_graph
,
fp16_reduction
):
@
pytest
.
mark
.
parametrize
(
"clip_grad_norm"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"amp"
,
_test_amp
)
def
test_ddp_parity
(
reduce_buffer_size
,
grad_accumulation
,
change_train_graph
,
fp16_reduction
,
clip_grad_norm
,
amp
):
world_size
=
torch
.
cuda
.
device_count
()
backend
=
dist
.
Backend
.
NCCL
mp
.
spawn
(
...
...
@@ -196,6 +215,8 @@ def test_ddp_parity(reduce_buffer_size, grad_accumulation, change_train_graph, f
grad_accumulation
,
change_train_graph
,
fp16_reduction
,
clip_grad_norm
,
amp
,
),
nprocs
=
world_size
,
join
=
True
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment