OpenDAS / fairscale / Commits / df493a29

Commit df493a29 (unverified), authored Mar 22, 2021 by Benjamin Lefaudeux, committed by GitHub on Mar 22, 2021.

[ci][SDP] extending the test matrix which checks for equivalence with DDP (#542)

Parent: fa1b85fb

Showing 1 changed file with 57 additions and 36 deletions.

tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py (+57, -36), view file @ df493a29
@@ -34,13 +34,26 @@ _test_fp16_reduction = [False]
 if hasattr(dist, "algorithms.ddp_com_hooks.default_hooks"):
     _test_fp16_reduction.append(True)
 
+_test_amp = [False]
+if hasattr(torch.cuda.amp, "autocast"):
+    _test_amp.append(True)
+
 
 def _get_mlp():
     return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
 
 
 def run_ddp_parity(
-    rank, world_size, backend, temp_file_name, reduce_buffer_size, grad_accumulation, change_train_graph, fp16_reduction
+    rank,
+    world_size,
+    backend,
+    temp_file_name,
+    reduce_buffer_size,
+    grad_accumulation,
+    change_train_graph,
+    fp16_reduction,
+    clip_grad_norm,
+    amp,
 ):
     dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
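Note on the hunk above: _test_amp is built at import time, so the AMP case only enters the test matrix when the running PyTorch build exposes torch.cuda.amp.autocast. A minimal, self-contained sketch of how such a gated list feeds pytest parametrization (the test function below is illustrative, not part of fairscale):

import pytest
import torch

_test_amp = [False]
if hasattr(torch.cuda.amp, "autocast"):
    # Older PyTorch builds have no autocast; the amp=True case is then never generated.
    _test_amp.append(True)


@pytest.mark.parametrize("amp", _test_amp)
def test_amp_case_is_gated(amp):
    # Collected once (amp=False) on old builds, twice (False and True) on recent ones.
    assert isinstance(amp, bool)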
@@ -51,7 +64,7 @@ def run_ddp_parity(
     NUMBER_BATCHS = 5
     BATCH_SIZE = 8
 
-    def check_parity(amp: bool, manual_reduction: bool):
+    def check_parity(manual_reduction: bool):
         # The API should be the exact same in between the sharded and non-sharded variants, generic closure
         def closure(model, scaler, input_tensor, should_accumulate, _manual_reduction=False):
@@ -108,7 +121,7 @@ def run_ddp_parity(
             ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook)  # type: ignore
 
         ddp_scaler = TorchGradScaler() if amp else None
-        sharded_ddp_scaler = ShardedGradScaler() if amp else None
+        sharded_scaler = ShardedGradScaler() if amp else None
 
         # The model should be synchronized in between the ranks at construction time, check that
         check_same_model_params(sharded_ddp_model, ddp_model)
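For context on the scaler pair above: when amp is enabled, each side of the parity check drives its optimizer through a gradient scaler (torch's GradScaler for DDP, fairscale's ShardedGradScaler for ShardedDDP). A minimal single-process sketch of that step/update pattern, assuming a CUDA device is available (the model and data are placeholders, not the test's MLP):

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(2, 3).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = GradScaler()  # the sharded side would use fairscale's ShardedGradScaler instead

optimizer.zero_grad()
with autocast():
    loss = model(torch.rand(8, 2, device="cuda")).sum()
scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(optimizer)         # unscales, then skips the step if inf/NaN grads were found
scaler.update()                # adjusts the scale factor for the next iteration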
@@ -117,35 +130,44 @@ def run_ddp_parity(
         for i in range(NUMBER_BATCHS):
             input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)
 
-            def closure_ddp(input_tensor=input_tensor):
+            def ddp_closure(input_tensor=input_tensor):
                 return closure(ddp_model, ddp_scaler, input_tensor, grad_accumulation)
 
-            def closure_sharded(input_tensor=input_tensor):
+            def sharded_closure(input_tensor=input_tensor):
                 return closure(
                     sharded_ddp_model,
-                    sharded_ddp_scaler,
+                    sharded_scaler,
                     input_tensor,
                     grad_accumulation,
                     _manual_reduction=manual_reduction,
                 )
 
             # Step/scale both
-            if ddp_scaler is not None:
-                _ = closure_ddp(input_tensor)
-                ddp_scaler.step(ddp_optimizer)
-                ddp_scaler.update()
-            else:
-                ddp_optimizer.step(closure=closure_ddp)
-
-            if sharded_ddp_scaler is not None:
-                _ = closure_sharded(input_tensor)
-                sharded_ddp_scaler.step(sharded_optimizer)
-                sharded_ddp_scaler.update()
-            else:
-                sharded_optimizer.step(closure=closure_sharded)
+            for _scaler, _closure, _optimizer in (
+                (ddp_scaler, ddp_closure, ddp_optimizer),
+                (sharded_scaler, sharded_closure, sharded_optimizer),
+            ):
+                if _scaler is not None:
+                    _ = _closure(input_tensor)
+                    _scaler.step(_optimizer)
+                    _scaler.update()
+                else:
+                    _optimizer.step(closure=_closure)
 
             check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")
 
+            # Check that the two grad norm are equivalent
+            # NOTE: The grads can occasionally be NaNs, the scaler will skip the step in that case
+            # This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
+            # be valid for ShardedDDP
+            if clip_grad_norm:
+                total_norm = torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 0.3, norm_type=2.0)  # type: ignore
+                if not torch.isnan(total_norm):
+                    oss_total_norm = sharded_optimizer.clip_grad_norm(0.3, norm_type=2.0)
+                    assert torch.allclose(
+                        oss_total_norm, total_norm, atol=1e-2 if amp else 1e-8
+                    ), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
+                else:
+                    print(rank, "NaN grad norm in DDP", flush=True)
+
             # Flip the trainability of the first parameter back and forth
             if i == 0 and change_train_graph:
                 next(sharded_ddp_model.parameters()).requires_grad = not next(
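The new clip_grad_norm branch above relies on torch.nn.utils.clip_grad_norm_ returning the total norm it measured, which is then compared against the value reported by the sharded optimizer's clip_grad_norm. A single-process sketch of the torch half only (the sharded half needs an initialized process group and a fairscale OSS optimizer, so it is left out):

import torch

model = torch.nn.Linear(2, 3)
model(torch.rand(8, 2)).sum().backward()

# clip_grad_norm_ clips the gradients in place and returns the total norm it computed;
# the test asserts that this matches sharded_optimizer.clip_grad_norm(0.3, norm_type=2.0).
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 0.3, norm_type=2.0)
print(float(total_norm))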
@@ -155,24 +177,19 @@ def run_ddp_parity(
             check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")
 
-    # Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
-    amp_tests = [False]
-    if hasattr(torch.cuda.amp, "autocast"):
-        amp_tests.append(True)
-
     manual_reductions = [False, True] if not grad_accumulation and not change_train_graph else [False]
     for manual_reduction in manual_reductions:
-        for amp in amp_tests:
-            print(
-                f"Checking configuration: accumulate {grad_accumulation}"
-                + f" - change train graph {change_train_graph}"
-                + f" - amp {amp}"
-                + f" - manual reduction {manual_reduction}"
-                + f" - buffers {reduce_buffer_size}",
-                flush=True,
-            )
-            check_parity(
-                amp=amp,
-                manual_reduction=manual_reduction,
-            )
+        print(
+            f"{rank}: Checking configuration: accumulate {grad_accumulation}"
+            + f" - change train graph {change_train_graph}"
+            + f" - amp {amp}"
+            + f" - manual reduction {manual_reduction}"
+            + f" - buffers {reduce_buffer_size}",
+            flush=True,
+        )
+        check_parity(manual_reduction=manual_reduction)
 
     torch.cuda.synchronize()
     torch.distributed.barrier()
     dist.destroy_process_group()
@@ -183,7 +200,9 @@ def run_ddp_parity(
 @pytest.mark.parametrize("grad_accumulation", [True, False])
 @pytest.mark.parametrize("change_train_graph", [True, False])
 @pytest.mark.parametrize("fp16_reduction", _test_fp16_reduction)
-def test_ddp_parity(reduce_buffer_size, grad_accumulation, change_train_graph, fp16_reduction):
+@pytest.mark.parametrize("clip_grad_norm", [True, False])
+@pytest.mark.parametrize("amp", _test_amp)
+def test_ddp_parity(reduce_buffer_size, grad_accumulation, change_train_graph, fp16_reduction, clip_grad_norm, amp):
     world_size = torch.cuda.device_count()
     backend = dist.Backend.NCCL
     mp.spawn(
@@ -196,6 +215,8 @@ def test_ddp_parity(reduce_buffer_size, grad_accumulation, change_train_graph, f
             grad_accumulation,
             change_train_graph,
             fp16_reduction,
+            clip_grad_norm,
+            amp,
         ),
         nprocs=world_size,
         join=True,
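Taken together, the two extra parametrize decorators and the two extra spawn arguments multiply the existing test matrix by 2 (clip_grad_norm) and by len(_test_amp). Stacked pytest.mark.parametrize decorators generate the full cross-product of their values; a small illustrative example (not the fairscale test itself):

import pytest


@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("clip_grad_norm", [True, False])
@pytest.mark.parametrize("amp", [False, True])
def test_cross_product(grad_accumulation, clip_grad_norm, amp):
    # pytest collects 2 * 2 * 2 = 8 instances of this single test definition.
    assert isinstance(grad_accumulation, bool)
    assert isinstance(clip_grad_norm, bool)
    assert isinstance(amp, bool)

Running, for example, pytest tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py -k ddp_parity would collect the expanded matrix in the same way.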