OpenDAS / fairscale · Commits

Commit 02405740 (Unverified)
[fix] oss and interleaved param groups (#483)

Authored Mar 08, 2021 by Benjamin Lefaudeux; committed by GitHub on Mar 08, 2021
Parent: 64bbb6e1

Changes: 2 changed files, with 58 additions and 43 deletions
  fairscale/optim/oss.py   (+25 -19)
  tests/optim/test_oss.py  (+33 -24)
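For readers skimming the diff below, here is a hedged, self-contained sketch of what "interleaved param groups" means in this context. The tiny trunk/head model and group layout are invented for illustration; only the interleaving pattern mirrors the test added in this commit. The key point is that `torch.optim` keys its saved per-parameter state by position in the flattened `param_groups`, so any other indexing scheme can fall out of sync once groups are interleaved.

```python
# Illustrative only: a hypothetical trunk/head model with interleaved param groups.
import torch

trunk = torch.nn.Linear(8, 8)
head = torch.nn.Linear(8, 2)

# Interleaved groups: trunk and head parameters are mixed across groups, so the
# flattened param_groups order differs from the module registration order.
param_groups = [
    {"params": list(trunk.parameters())[:-1] + list(head.parameters()), "lr": 1e-5},
    {"params": [list(trunk.parameters())[-1]], "lr": 1e-4},
]
opt = torch.optim.SGD(param_groups, lr=1e-5)

# torch.optim saves per-parameter state keyed by position in this flattened order,
# which is the ordering OSS.load_state_dict now relies on.
flat_order = [p for g in opt.param_groups for p in g["params"]]
print(len(flat_order))  # 4 params: trunk.weight, head.weight, head.bias, trunk.bias
```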
fairscale/optim/oss.py  (view file @ 02405740)

@@ -85,7 +85,6 @@ class OSS(Optimizer):
         self._per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict()  # device, rank, params
         self._param_rank: Dict[torch.Tensor, int] = {}
         self._partition_parameters: List[List[dict]] = []
-        self._index_to_param: Dict[int, torch.Tensor] = {}
         self._param_to_index: Dict[int, int] = {}
         self._local_params: Optional[List[torch.Tensor]] = None

@@ -160,15 +159,6 @@ class OSS(Optimizer):
         # Make sure that the iterator is not consumed, only expose a copy
         return self._local_params

-    @property
-    def index_to_param(self) -> Dict[int, torch.Tensor]:
-        """ Hash table in between parameter indices in the global optimizer scheme, and the actual params
-        """
-        if len(self._index_to_param) == 0:
-            self._index_to_param = {i: p for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))}
-
-        return self._index_to_param
-
     @property
     def param_to_index(self) -> Dict[int, int]:
         """ Hash table in between parameter indices in the global optimizer scheme, and the actual params

@@ -376,7 +366,7 @@ class OSS(Optimizer):
                 global_id = self.param_to_index[local_index_to_param_id[local_param_index]]
                 state_dict["state"][global_id] = s["state"][local_param_index]

-        # Make sure that the parameters are sorted in the state, as expected
+        # Make sure that the parameters are sorted in the state, as expected for a pytorch dict
         state_dict["state"] = dict(sorted(state_dict["state"].items()))

         return state_dict

@@ -389,17 +379,34 @@ class OSS(Optimizer):
                 from a call to :meth:`state_dict`
         """
-        # NOTE: PyTorch 1.5 does not index linearly but with the id(params) at saving time
-        # we work around that here by using the fact that the params are ordered as in the param_groups
-        pytorch15_index_redirect = {k: i for i, k in enumerate(state_dict["state"].keys())}
+        # Update the state, trusting the ordering in param_groups
+        # Apart from the removal of states not owned by this rank, the pytorch logic is kept
+        # (See torch.optim.optimizer)
+        id_map = {
+            old_id: p
+            for old_id, p in zip(
+                chain.from_iterable((g["params"] for g in state_dict["param_groups"])),
+                chain.from_iterable((g["params"] for g in self.param_groups)),
+            )
+        }
+
+        # FIXME: pytorch1.5 compatibility, to be removed when 1.5 support ends
+        _param_list = list(chain.from_iterable((g["params"] for g in self.param_groups)))

         for key, value in state_dict["state"].items():
-            param = self.index_to_param[pytorch15_index_redirect[key]]
-
-            # Populate the sharded optimizer state on the fly
-            if self.param_to_rank[param] != self.rank:
-                state_dict["state"][key] = None
-            else:
-                self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
+            if key in id_map:
+                param = id_map[key]
+
+                # Populate the sharded optimizer state on the fly,
+                # remove the params that this rank does not own
+                if self.param_to_rank[param] != self.rank:
+                    state_dict["state"][key] = None
+                else:
+                    self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
+            else:
+                # Not a param, copied as-is (backward compatibility or exotic optimizers)
+                print(key, "not in idmap")
+                param = _param_list[key]
+                self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)

         super().load_state_dict(state_dict)

@@ -515,7 +522,6 @@ class OSS(Optimizer):
         self._partition_parameters.clear()
         self._per_device_params.clear()
         self._param_rank.clear()
-        self._index_to_param.clear()
         self._param_to_index.clear()
         self._local_params = None
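The rewritten `load_state_dict` above follows the stock `torch.optim.Optimizer` convention: the keys of the saved `state` are positions in the flattened `param_groups` of the checkpoint, and they are zipped against the flattened `param_groups` of the live optimizer. Below is a minimal standalone sketch of that remapping; the function name and inputs are hypothetical, not fairscale API, and the per-rank sharding filter is left out.

```python
# Minimal sketch of the key-to-parameter remapping used by torch.optim-style
# load_state_dict; names are illustrative, not fairscale API.
from itertools import chain
from typing import Any, Dict, List

import torch


def remap_state(saved: Dict[str, Any], live_param_groups: List[Dict[str, Any]]) -> Dict[torch.Tensor, Any]:
    # Zip the flattened saved parameter ids against the flattened live parameters,
    # trusting that both sides enumerate their param_groups in the same order.
    id_map = {
        old_id: p
        for old_id, p in zip(
            chain.from_iterable(g["params"] for g in saved["param_groups"]),
            chain.from_iterable(g["params"] for g in live_param_groups),
        )
    }
    # Re-key the per-parameter state by the live parameter objects; entries whose
    # key is not in the map would go through a fallback path (pytorch 1.5 ids).
    return {id_map[k]: v for k, v in saved["state"].items() if k in id_map}
```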
tests/optim/test_oss.py  (view file @ 02405740)

@@ -22,7 +22,13 @@ import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP

 import fairscale.optim as optim
-from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_py39_no_cuda, skip_if_single_gpu
+from fairscale.utils.testing import (
+    check_same_model_params,
+    skip_if_no_cuda,
+    skip_if_py39_no_cuda,
+    skip_if_single_gpu,
+    torch_version,
+)

 BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
 DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")

@@ -811,9 +817,11 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
         # Define a model to be trained by OSS
         oss_module = torch.nn.Sequential(trunk, head)

+        # Make sure that the param groups are interleaved, to catch an ordering bug in the state dict
         oss_trainable_params = [
-            {"params": trunk.parameters(), "lr": 1e-5},
-            {"params": head.parameters(), "lr": 1e-4},
+            {"params": list(trunk.parameters())[:-1] + list(head.parameters()), "lr": 1e-5},
+            {"params": list(trunk.parameters())[-1], "lr": 1e-4},
         ]

         optimizer_settings: Dict[Any, Any] = {}

@@ -836,8 +844,8 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
         ddp_module = torch.nn.Sequential(ddp_trunk, ddp_head)

         ddp_trainable_params = [
-            {"params": ddp_trunk.parameters(), "lr": 1e-5},
-            {"params": ddp_head.parameters(), "lr": 1e-4},
+            {"params": list(ddp_trunk.parameters())[:-1] + list(ddp_head.parameters()), "lr": 1e-5},
+            {"params": list(ddp_trunk.parameters())[-1], "lr": 1e-4},
         ]
         ddp_optimizer = optimizer(ddp_trainable_params, **optimizer_settings)  # type: ignore
         ddp_model = DDP(module=ddp_module, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

@@ -880,25 +888,26 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
         next(oss_module.parameters()).requires_grad = not next(oss_module.parameters()).requires_grad
         # sharded_optimizer.refresh_trainable()

-        # Check that the checkpoints are compatible
-        # - get states
-        ddp_state_dict = ddp_optimizer.state_dict()
-        sharded_optimizer.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)
-        sharded_optim_state_dict = sharded_optimizer.state_dict() if rank == RECIPIENT_RANK else {}
-        sharded_optim_state_dict = sync_object_ranks(sharded_optim_state_dict, RECIPIENT_RANK, device)
+        # Check that the checkpoints are compatible (post pytorch 1.5)
+        if torch_version()[1] > 5:
+            # - get states
+            ddp_state_dict = ddp_optimizer.state_dict()
+            sharded_optimizer.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)
+            sharded_optim_state_dict = sharded_optimizer.state_dict() if rank == RECIPIENT_RANK else {}
+            sharded_optim_state_dict = sync_object_ranks(sharded_optim_state_dict, RECIPIENT_RANK, device)

-        # - cross load the states
-        # run one step and check that the models are still the same
-        ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)  # OSS will remove some states
-        ddp_optimizer.load_state_dict(sharded_optim_state_dict)  # mixup on purpose !
-        sharded_optimizer.load_state_dict(ddp_state_dict)
-        check_step()
+            # - cross load the states
+            # run one step and check that the models are still the same
+            ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)  # OSS will remove some states
+            ddp_optimizer.load_state_dict(sharded_optim_state_dict)  # mixup on purpose !
+            sharded_optimizer.load_state_dict(ddp_state_dict)
+            check_step()

-        # - self load, rewind, check no problem
-        # run one step and check that the models are still the same
-        ddp_optimizer.load_state_dict(ddp_state_dict_ref)
-        sharded_optimizer.load_state_dict(sharded_optim_state_dict)
-        check_step()
+            # - self load, rewind, check no problem
+            # run one step and check that the models are still the same
+            ddp_optimizer.load_state_dict(ddp_state_dict_ref)
+            sharded_optimizer.load_state_dict(sharded_optim_state_dict)
+            check_step()

     for opt in [torch.optim.Adam, torch.optim.SGD]:
         check_optimizer_equivalence(opt, change_train_graph=False)
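The parity test above needs multiple processes plus DDP and OSS, but the property it guards is easy to see single-process: a state dict saved from one optimizer built over interleaved groups should load into another optimizer with the same group layout, entry for entry. A hedged, single-process sketch with plain `torch.optim.Adam` follows (no distributed setup; model and helper names are invented):

```python
# Single-process illustration of the cross-load idea; plain torch.optim only.
import copy

import torch


def interleaved_groups(trunk: torch.nn.Module, head: torch.nn.Module):
    # Same interleaving pattern as the test: trunk and head params mixed across groups.
    return [
        {"params": list(trunk.parameters())[:-1] + list(head.parameters()), "lr": 1e-5},
        {"params": [list(trunk.parameters())[-1]], "lr": 1e-4},
    ]


trunk, head = torch.nn.Linear(4, 4), torch.nn.Linear(4, 2)
opt_a = torch.optim.Adam(interleaved_groups(trunk, head))
opt_b = torch.optim.Adam(interleaved_groups(trunk, head))

# One step so opt_a owns per-parameter state (exp_avg, exp_avg_sq, step).
torch.nn.Sequential(trunk, head)(torch.randn(3, 4)).sum().backward()
opt_a.step()

# Cross-load: opt_b receives opt_a's state; the interleaved ordering must survive.
opt_b.load_state_dict(copy.deepcopy(opt_a.state_dict()))

state_a = opt_a.state_dict()["state"]
state_b = opt_b.state_dict()["state"]
assert state_a.keys() == state_b.keys()
assert all(torch.equal(state_a[k]["exp_avg"], state_b[k]["exp_avg"]) for k in state_a)
```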