OpenDAS / fairscale / Commits

Commit 60c8de4a (unverified), authored Dec 28, 2020 by Joshua Meier, committed by GitHub on Dec 28, 2020

[feature] OSS: add unit test for distributed checkpointing (#273)

author: Joshua Meier

parent b640cab5
Showing 1 changed file with 98 additions and 0 deletions

tests/optim/test_oss.py  +98  −0
tests/optim/test_oss.py @ 60c8de4a
@@ -565,3 +565,101 @@ def test_gradient_clipping():
    mp.spawn(
        run_gradient_clipping, args=(world_size, temp_file_name), nprocs=world_size, join=True,
    )


def run_state_dict_distributed(rank, world_size, tempfile_name):
    dist_init(rank, world_size, tempfile_name, backend="gloo")
    device = torch.device(rank)
    torch.manual_seed(rank)  # make sure that the different rank get different data

    # Run a dummy step so that the optimizer state dict exists
    batch, input_width, hidden, target_width = 3, 20, 10, 5
    target = torch.rand((batch, target_width), device=device)
    inputs = torch.rand((batch, input_width), device=device)

    model_oss1 = torch.nn.Sequential(
        torch.nn.Linear(input_width, hidden),
        torch.nn.Linear(hidden, hidden),
        torch.nn.Linear(hidden, target_width),
    ).to(device)
    model_oss2 = copy.deepcopy(model_oss1)

    # For this test the gradients are (all) reduced in the same way in between the torch reference and fairscale.
    # Normally OSS would use ShardedDDP and only reduce to the proper rank, but this does not change the
    # gradient norm computation from OSS and adds a dependency.
    # to keep the comparison apples-to-apples DDP is used in both cases
    model_oss1 = DDP(module=model_oss1, device_ids=[rank],)
    sharded_optimizer1 = optim.OSS(model_oss1.parameters(), lr=0.1, momentum=0.99)

    model_oss2 = DDP(module=model_oss2, device_ids=[rank],)
    sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)

    def run_grad_step(device, model, optimizer):
        loss_fn = torch.nn.L1Loss()
        loss_fn.to(device)

        model.zero_grad()

        outputs = model(inputs)
        loss = loss_fn(outputs, target)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

    # take a step
    run_grad_step(device, model_oss1, sharded_optimizer1)
    run_grad_step(device, model_oss2, sharded_optimizer2)

    # check that model parameters are equal
    for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
        assert torch.allclose(param1, param2), "parameters of the two identical models have diverged (before saving)"

    # save the state dict for one model only
    sharded_optimizer2.consolidate_state_dict()
    state_dict2 = sharded_optimizer2.state_dict()

    # Check that the pulled state and the .param_groups attribute are in sync
    for replica in range(len(state_dict2["param_groups"])):
        for k in state_dict2["param_groups"][replica].keys():
            if k != "params":
                assert state_dict2["param_groups"][replica][k] == sharded_optimizer2.param_groups[0][k]

    # take a step
    run_grad_step(device, model_oss1, sharded_optimizer1)
    run_grad_step(device, model_oss2, sharded_optimizer2)

    # check that saving did not cause a change in the parameters
    for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
        assert torch.allclose(
            param1, param2
        ), "parameters of the two identical models have diverged (after consolidating)"

    # save again
    sharded_optimizer2.consolidate_state_dict()
    state_dict2 = sharded_optimizer2.state_dict()

    # reload the state_dict
    sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)
    sharded_optimizer2.load_state_dict(state_dict2)

    # take a step
    run_grad_step(device, model_oss1, sharded_optimizer1)
    run_grad_step(device, model_oss2, sharded_optimizer2)

    # check that reloading a saved state dict does not change the parameters
    for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
        assert torch.allclose(param1, param2), "parameters of the two identical models have diverged (after reloading)"

    dist.destroy_process_group()


@skip_if_no_cuda
def test_state_dict_distributed():
    world_size = 8
    temp_file_name = tempfile.mkstemp()[1]

    if torch.cuda.is_available():
        world_size = min(world_size, torch.cuda.device_count())

    mp.spawn(
        run_state_dict_distributed, args=(world_size, temp_file_name), nprocs=world_size, join=True,
    )
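For context, the checkpointing flow the new test exercises boils down to three OSS calls: consolidate_state_dict() to gather the sharded optimizer state across ranks, state_dict() to pull a regular dict, and load_state_dict() to restore it into a freshly constructed OSS instance. The following is a minimal sketch, not part of the commit; the helper name, model argument, and hyperparameters are placeholder assumptions that simply mirror the values used in the test above.

import torch
from fairscale.optim import OSS

def oss_checkpoint_roundtrip(model: torch.nn.Module) -> OSS:
    # Build a sharded optimizer over the model parameters (same settings as the test).
    optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)

    # ... run a few training steps so that optimizer state exists ...

    # Gather the sharded state; every rank takes part in this collective call.
    optimizer.consolidate_state_dict()
    state = optimizer.state_dict()

    # Restore into a freshly built OSS optimizer, as the test does.
    restored = OSS(model.parameters(), lr=0.1, momentum=0.99)
    restored.load_state_dict(state)
    return restored

In a real training script the consolidated dict would typically be written to disk on a single rank and read back before load_state_dict(); the test keeps everything in memory to focus on the round-trip itself.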