Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
60c8de4a
Unverified
Commit
60c8de4a
authored
Dec 28, 2020
by
Joshua Meier
Committed by
GitHub
Dec 28, 2020
Browse files
[feature] OSS: add unit test for distributed checkpointing (#273)
author: Joshua Meier
parent
b640cab5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
98 additions
and
0 deletions
+98
-0
tests/optim/test_oss.py
tests/optim/test_oss.py
+98
-0
No files found.
tests/optim/test_oss.py
View file @
60c8de4a
...
@@ -565,3 +565,101 @@ def test_gradient_clipping():
...
@@ -565,3 +565,101 @@ def test_gradient_clipping():
mp
.
spawn
(
mp
.
spawn
(
run_gradient_clipping
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
,
run_gradient_clipping
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
,
)
)
def
run_state_dict_distributed
(
rank
,
world_size
,
tempfile_name
):
dist_init
(
rank
,
world_size
,
tempfile_name
,
backend
=
"gloo"
)
device
=
torch
.
device
(
rank
)
torch
.
manual_seed
(
rank
)
# make sure that the different rank get different data
# Run a dummy step so that the optimizer state dict exists
batch
,
input_width
,
hidden
,
target_width
=
3
,
20
,
10
,
5
target
=
torch
.
rand
((
batch
,
target_width
),
device
=
device
)
inputs
=
torch
.
rand
((
batch
,
input_width
),
device
=
device
)
model_oss1
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Linear
(
input_width
,
hidden
),
torch
.
nn
.
Linear
(
hidden
,
hidden
),
torch
.
nn
.
Linear
(
hidden
,
target_width
),
).
to
(
device
)
model_oss2
=
copy
.
deepcopy
(
model_oss1
)
# For this test the gradients are (all) reduced in the same way in between the torch reference and fairscale.
# Normally OSS would use ShardedDDP and only reduce to the proper rank, but this does not change the
# gradient norm computation from OSS and adds a dependency.
# to keep the comparison apples-to-apples DDP is used in both cases
model_oss1
=
DDP
(
module
=
model_oss1
,
device_ids
=
[
rank
],)
sharded_optimizer1
=
optim
.
OSS
(
model_oss1
.
parameters
(),
lr
=
0.1
,
momentum
=
0.99
)
model_oss2
=
DDP
(
module
=
model_oss2
,
device_ids
=
[
rank
],)
sharded_optimizer2
=
optim
.
OSS
(
model_oss2
.
parameters
(),
lr
=
0.1
,
momentum
=
0.99
)
def
run_grad_step
(
device
,
model
,
optimizer
):
loss_fn
=
torch
.
nn
.
L1Loss
()
loss_fn
.
to
(
device
)
model
.
zero_grad
()
outputs
=
model
(
inputs
)
loss
=
loss_fn
(
outputs
,
target
)
loss
.
backward
()
optimizer
.
step
()
optimizer
.
zero_grad
()
# take a step
run_grad_step
(
device
,
model_oss1
,
sharded_optimizer1
)
run_grad_step
(
device
,
model_oss2
,
sharded_optimizer2
)
# check that model parameters are equal
for
param1
,
param2
in
zip
(
model_oss1
.
parameters
(),
model_oss2
.
parameters
()):
assert
torch
.
allclose
(
param1
,
param2
),
"parameters of the two identical models have diverged (before saving)"
# save the state dict for one model only
sharded_optimizer2
.
consolidate_state_dict
()
state_dict2
=
sharded_optimizer2
.
state_dict
()
# Check that the pulled state and the .param_groups attribute are in sync
for
replica
in
range
(
len
(
state_dict2
[
"param_groups"
])):
for
k
in
state_dict2
[
"param_groups"
][
replica
].
keys
():
if
k
!=
"params"
:
assert
state_dict2
[
"param_groups"
][
replica
][
k
]
==
sharded_optimizer2
.
param_groups
[
0
][
k
]
# take a step
run_grad_step
(
device
,
model_oss1
,
sharded_optimizer1
)
run_grad_step
(
device
,
model_oss2
,
sharded_optimizer2
)
# check that saving did not cause a change in the parameters
for
param1
,
param2
in
zip
(
model_oss1
.
parameters
(),
model_oss2
.
parameters
()):
assert
torch
.
allclose
(
param1
,
param2
),
"parameters of the two identical models have diverged (after consolidating)"
# save again
sharded_optimizer2
.
consolidate_state_dict
()
state_dict2
=
sharded_optimizer2
.
state_dict
()
# reload the state_dict
sharded_optimizer2
=
optim
.
OSS
(
model_oss2
.
parameters
(),
lr
=
0.1
,
momentum
=
0.99
)
sharded_optimizer2
.
load_state_dict
(
state_dict2
)
# take a step
run_grad_step
(
device
,
model_oss1
,
sharded_optimizer1
)
run_grad_step
(
device
,
model_oss2
,
sharded_optimizer2
)
# check that reloading a saved state dict does not change the parameters
for
param1
,
param2
in
zip
(
model_oss1
.
parameters
(),
model_oss2
.
parameters
()):
assert
torch
.
allclose
(
param1
,
param2
),
"parameters of the two identical models have diverged (after reloading)"
dist
.
destroy_process_group
()
@
skip_if_no_cuda
def
test_state_dict_distributed
():
world_size
=
8
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
if
torch
.
cuda
.
is_available
():
world_size
=
min
(
world_size
,
torch
.
cuda
.
device_count
())
mp
.
spawn
(
run_state_dict_distributed
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment