OpenDAS / fairscale

Commit 02405740 (unverified), authored Mar 08, 2021 by Benjamin Lefaudeux, committed via GitHub on Mar 08, 2021

[fix] oss and interleaved param groups (#483)
parent 64bbb6e1

Showing 2 changed files with 58 additions and 43 deletions (+58 -43):

  fairscale/optim/oss.py    +25 -19
  tests/optim/test_oss.py   +33 -24
fairscale/optim/oss.py
@@ -85,7 +85,6 @@ class OSS(Optimizer):
        self._per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict()  # device, rank, params
        self._param_rank: Dict[torch.Tensor, int] = {}
        self._partition_parameters: List[List[dict]] = []
        self._index_to_param: Dict[int, torch.Tensor] = {}
        self._param_to_index: Dict[int, int] = {}
        self._local_params: Optional[List[torch.Tensor]] = None
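The trailing "# device, rank, params" comment describes a nested layout: parameters bucketed first by device, then by the rank that owns them. A rough sketch of that shape in plain PyTorch, with a toy round-robin assignment standing in for the real partitioning (the function name and split logic are illustrative assumptions, not fairscale internals):

from collections import OrderedDict
from typing import Dict, List

import torch

def per_device_per_rank(
    params: List[torch.nn.Parameter], world_size: int
) -> Dict[torch.device, List[List[torch.nn.Parameter]]]:
    # Same nesting as _per_device_params: device -> rank -> list of params.
    table: Dict[torch.device, List[List[torch.nn.Parameter]]] = OrderedDict()
    for i, p in enumerate(params):
        buckets = table.setdefault(p.device, [[] for _ in range(world_size)])
        buckets[i % world_size].append(p)  # toy round-robin split, purely illustrative
    return table

model = torch.nn.Linear(8, 8)
print(per_device_per_rank(list(model.parameters()), world_size=2))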
@@ -160,15 +159,6 @@ class OSS(Optimizer):
        # Make sure that the iterator is not consumed, only expose a copy
        return self._local_params

    @property
    def index_to_param(self) -> Dict[int, torch.Tensor]:
        """ Hash table in between parameter indices in the global optimizer scheme, and the actual params
        """
        if len(self._index_to_param) == 0:
            self._index_to_param = {i: p for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))}

        return self._index_to_param

    @property
    def param_to_index(self) -> Dict[int, int]:
        """ Hash table in between parameter indices in the global optimizer scheme, and the actual params
@@ -376,7 +366,7 @@ class OSS(Optimizer):
                global_id = self.param_to_index[local_index_to_param_id[local_param_index]]
                state_dict["state"][global_id] = s["state"][local_param_index]

-        # Make sure that the parameters are sorted in the state, as expected
+        # Make sure that the parameters are sorted in the state, as expected for a pytorch dict
        state_dict["state"] = dict(sorted(state_dict["state"].items()))

        return state_dict
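The sort matters because each rank contributes its shard's entries in whatever order consolidation delivers them, while a plain pytorch state dict is expected to be keyed by ascending parameter index. A toy illustration (not fairscale code):

# Entries arrive keyed by global parameter index, but not necessarily in order.
state = {2: {"momentum_buffer": None}, 0: {"momentum_buffer": None}, 1: {"momentum_buffer": None}}
state = dict(sorted(state.items()))
assert list(state.keys()) == [0, 1, 2]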
@@ -389,18 +379,35 @@ class OSS(Optimizer):
                from a call to :meth:`state_dict`
        """

-        # NOTE: PyTorch 1.5 does not index linearly but with the id(params) at saving time
-        # we work around that here by using the fact that the params are ordered as in the param_groups
-        pytorch15_index_redirect = {k: i for i, k in enumerate(state_dict["state"].keys())}
+        # Update the state, trusting the ordering in param_groups
+        # Apart from the removal of states not owned by this rank, the pytorch logic is kept
+        # (See torch.optim.optimizer)
+        id_map = {
+            old_id: p
+            for old_id, p in zip(
+                chain.from_iterable((g["params"] for g in state_dict["param_groups"])),
+                chain.from_iterable((g["params"] for g in self.param_groups)),
+            )
+        }
+
+        # FIXME: pytorch1.5 compatibility, to be removed when 1.5 support ends
+        _param_list = list(chain.from_iterable((g["params"] for g in self.param_groups)))

        for key, value in state_dict["state"].items():
-            param = self.index_to_param[pytorch15_index_redirect[key]]
+            if key in id_map:
+                param = id_map[key]

-            # Populate the sharded optimizer state on the fly
+                # Populate the sharded optimizer state on the fly,
+                # remove the params that this rank does not own
                if self.param_to_rank[param] != self.rank:
                    state_dict["state"][key] = None
                else:
                    self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
+            else:
+                # Not a param, copied as-is (backward compatibility or exotic optimizers)
+                print(key, "not in idmap")
+                param = _param_list[key]
+                self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)

        super().load_state_dict(state_dict)
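The heart of the fix is the id_map above: the keys stored in a checkpoint's "state" follow the saved param_groups ordering, so zipping the saved "params" indices with the live parameters, group by group, resolves each saved entry to the right tensor even when the groups are interleaved. A self-contained sketch of that idea with a plain torch optimizer (toy code under those assumptions, not the fairscale API):

from itertools import chain

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
params = list(model.parameters())
groups = [
    {"params": params[:-1], "lr": 1e-5},
    {"params": params[-1:], "lr": 1e-4},
]
opt = torch.optim.SGD(groups, momentum=0.9)

# Run one step so the optimizer actually has per-parameter state to restore.
model(torch.randn(2, 4)).sum().backward()
opt.step()
saved = opt.state_dict()

# Saved "params" entries are integer keys; live "params" entries are the tensors,
# in the same group order. Zipping the two flattened sequences pairs them up.
id_map = {
    old_id: p
    for old_id, p in zip(
        chain.from_iterable(g["params"] for g in saved["param_groups"]),
        chain.from_iterable(g["params"] for g in opt.param_groups),
    )
}

for key, value in saved["state"].items():
    live_param = id_map[key]  # correct tensor even if the groups are interleaved
    assert live_param.shape == value["momentum_buffer"].shape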
@@ -515,7 +522,6 @@ class OSS(Optimizer):
        self._partition_parameters.clear()
        self._per_device_params.clear()
        self._param_rank.clear()
        self._index_to_param.clear()
        self._param_to_index.clear()
        self._local_params = None
tests/optim/test_oss.py
@@ -22,7 +22,13 @@ import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import fairscale.optim as optim
-from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_py39_no_cuda, skip_if_single_gpu
+from fairscale.utils.testing import (
+    check_same_model_params,
+    skip_if_no_cuda,
+    skip_if_py39_no_cuda,
+    skip_if_single_gpu,
+    torch_version,
+)

BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")
@@ -811,9 +817,11 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
    # Define a model to be trained by OSS
    oss_module = torch.nn.Sequential(trunk, head)

+    # Make sure that the param groups are interleaved, to catch an ordering bug in the state dict
    oss_trainable_params = [
-        {"params": trunk.parameters(), "lr": 1e-5},
-        {"params": head.parameters(), "lr": 1e-4},
+        {"params": list(trunk.parameters())[:-1] + list(head.parameters()), "lr": 1e-5},
+        {"params": list(trunk.parameters())[-1], "lr": 1e-4},
    ]

    optimizer_settings: Dict[Any, Any] = {}
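The interleaving here is the point of the test: with these groups, the flattened group order no longer matches the module's parameter order, which is exactly what trips up an index-based state-dict mapping. A small sketch of how a plain torch optimizer serializes such groups (illustrative only, not the test itself):

import torch

trunk, head = torch.nn.Linear(4, 4), torch.nn.Linear(4, 2)
groups = [
    {"params": list(trunk.parameters())[:-1] + list(head.parameters()), "lr": 1e-5},
    {"params": list(trunk.parameters())[-1], "lr": 1e-4},
]
opt = torch.optim.SGD(groups)

# state_dict() replaces tensors by integer indices assigned in group order, so the
# last trunk parameter (trunk.bias) becomes index 3 even though trunk comes first in the model.
sd = opt.state_dict()
print(sd["param_groups"][0]["params"])  # [0, 1, 2]
print(sd["param_groups"][1]["params"])  # [3]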
@@ -836,8 +844,8 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
    ddp_module = torch.nn.Sequential(ddp_trunk, ddp_head)

    ddp_trainable_params = [
-        {"params": ddp_trunk.parameters(), "lr": 1e-5},
-        {"params": ddp_head.parameters(), "lr": 1e-4},
+        {"params": list(ddp_trunk.parameters())[:-1] + list(ddp_head.parameters()), "lr": 1e-5},
+        {"params": list(ddp_trunk.parameters())[-1], "lr": 1e-4},
    ]
    ddp_optimizer = optimizer(ddp_trainable_params, **optimizer_settings)  # type: ignore
    ddp_model = DDP(module=ddp_module, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)
@@ -880,7 +888,8 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
            next(oss_module.parameters()).requires_grad = not next(oss_module.parameters()).requires_grad
            # sharded_optimizer.refresh_trainable()

-        # Check that the checkpoints are compatible
+        # Check that the checkpoints are compatible (post pytorch 1.5)
+        if torch_version()[1] > 5:
            # - get states
            ddp_state_dict = ddp_optimizer.state_dict()
            sharded_optimizer.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)
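The new guard relies on torch_version() from fairscale.utils.testing returning a comparable version tuple. A minimal stand-in with that behaviour (an assumption for illustration, not the actual helper):

from typing import Tuple

import torch

def torch_version_tuple() -> Tuple[int, ...]:
    # Hypothetical re-creation: "1.8.1+cu111" or "1.9.0a0+gitabc" -> (1, 8, 1) / (1, 9, 0)
    fields = torch.__version__.split("+")[0].split(".")[:3]
    return tuple(int("".join(c for c in field if c.isdigit()) or 0) for field in fields)

# Mirror of the check in the test: only exercise checkpoint compatibility past pytorch 1.5.
if torch_version_tuple()[1] > 5:
    print("running the state_dict compatibility checks")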