OpenDAS / fairscale, commit b0e6b9bd (unverified)
Authored Apr 20, 2021 by Benjamin Lefaudeux; committed by GitHub on Apr 20, 2021

[chore] OSS to 100% coverage (#618)

Parent: d9f36130
Changes: 1 changed file, with 6 additions and 28 deletions.

fairscale/optim/oss.py (+6, -28)
...
@@ -301,20 +301,6 @@ class OSS(Optimizer):
                 logging.debug("State from rank %s received", rank)
 
-    def local_state_dict(self) -> dict:
-        """ .. deprecated:: 0.1.5
-
-        Returns this rank's state_dict as a :class:`dict` which contains two entries:
-
-        * state - a dict holding current optimization state. Its content
-            differs between optimizer classes.
-
-        * param_groups - a dict containing all parameter groups
-
-        .. warning: This does not represent the optimizer state dict, only a shard.
-        """
-        return self.optim.state_dict()
-
     def state_dict(self, all_ranks: bool = False) -> Dict[str, Any]:
         """Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the
         sharded properties are not exposed.
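For context, the state_dict kept here is the global one: it is only complete after every rank's shard has been pulled onto the recipient rank. A minimal checkpointing sketch, assuming torch.distributed is already initialized; the model, learning rate and file name below are illustrative stand-ins, not part of this commit:

import torch
import torch.distributed as dist
from fairscale.optim.oss import OSS

# Assumes dist.init_process_group(...) has already run and every rank
# builds the same model; a tiny Linear layer stands in for a real model.
model = torch.nn.Linear(8, 8)

# Wrap any torch.optim optimizer; OSS shards its state across ranks.
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-2)

# ... training loop elided: forward pass, loss.backward(), optimizer.step() ...

# Gather every rank's shard onto the recipient rank (rank 0 by default), then
# save the PyTorch-compatible, unsharded state from that rank only.
optimizer.consolidate_state_dict(recipient_rank=0)
if dist.get_rank() == 0:
    torch.save(optimizer.state_dict(), "oss_optimizer.pt")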
...
@@ -390,22 +376,14 @@ class OSS(Optimizer):
             )
         }
 
-        # FIXME: pytorch1.5 compatibility, to be removed when 1.5 support ends
-        _param_list = list(chain.from_iterable((g["params"] for g in self.param_groups)))
-
         for key, value in state_dict["state"].items():
-            if key in id_map:
-                param = id_map[key]
-
-                # Populate the sharded optimizer state on the fly,
-                # remove the params that this rank does not own
-                if self._param_to_rank[param] != self.rank:
-                    state_dict["state"][key] = {}
-                else:
-                    self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
+            param = id_map[key]
+
+            # Populate the sharded optimizer state on the fly,
+            # remove the params that this rank does not own
+            if self._param_to_rank[param] != self.rank:
+                state_dict["state"][key] = {}
             else:
-                # Not a param, copied as-is (backward compatibility or exotic optimizers)
-                param = _param_list[key]
                 self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
 
         super().load_state_dict(state_dict)
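The loop above is the sharding filter inside load_state_dict: every rank receives the full consolidated state, keeps only the entries for parameters it owns, and moves those tensors onto the owning parameter's device. A self-contained sketch of the same idea; the helper name shard_optimizer_state and its arguments are hypothetical, not fairscale's API:

from typing import Any, Dict

import torch


def shard_optimizer_state(
    state: Dict[int, Dict[str, Any]],
    id_map: Dict[int, torch.nn.Parameter],
    param_to_rank: Dict[torch.nn.Parameter, int],
    my_rank: int,
) -> Dict[int, Dict[str, Any]]:
    """Keep optimizer state only for parameters owned by this rank."""
    sharded: Dict[int, Dict[str, Any]] = {}
    for key, value in state.items():
        param = id_map[key]
        if param_to_rank[param] != my_rank:
            # Not owned by this rank: keep an empty placeholder,
            # mirroring state_dict["state"][key] = {} in the diff above.
            sharded[key] = {}
        else:
            # Owned by this rank: move tensors onto the parameter's device.
            sharded[key] = {
                k: v.to(param.device) if torch.is_tensor(v) else v
                for k, v in value.items()
            }
    return sharded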
...