OpenDAS / fairscale / Commits

Commit b0e6b9bd (unverified)
Authored Apr 20, 2021 by Benjamin Lefaudeux, committed by GitHub on Apr 20, 2021

[chore] OSS to 100% coverage (#618)
Parent: d9f36130

1 changed file with 6 additions and 28 deletions:
fairscale/optim/oss.py  (+6, -28)

fairscale/optim/oss.py @ b0e6b9bd
@@ -301,20 +301,6 @@ class OSS(Optimizer):
             logging.debug("State from rank %s received", rank)

-    def local_state_dict(self) -> dict:
-        """ .. deprecated:: 0.1.5
-
-        Returns this rank's state_dict as a :class:`dict` which contains two entries:
-
-        * state - a dict holding current optimization state. Its content
-            differs between optimizer classes.
-
-        * param_groups - a dict containing all parameter groups
-
-        .. warning: This does not represent the optimizer state dict, only a shard.
-        """
-        return self.optim.state_dict()
-
     def state_dict(self, all_ranks: bool = False) -> Dict[str, Any]:
         """Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the
         sharded properties are not exposed.
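The hunk above drops the deprecated local_state_dict() helper, which simply returned the wrapped optimizer's shard for the calling rank. A minimal sketch of how both views remain reachable after this change, assuming a fairscale OSS wrapper around a plain torch optimizer; the single-process group, model, and hyperparameters below are illustrative, not part of the commit:

# Illustration only: a single-process "gloo" group so the sketch is runnable;
# real OSS usage spans several ranks.
import torch
import torch.distributed as dist
from fairscale.optim.oss import OSS

if not dist.is_initialized():
    dist.init_process_group(backend="gloo", init_method="tcp://localhost:29500", rank=0, world_size=1)

model = torch.nn.Linear(8, 8)  # hypothetical model
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

# Rank-local shard: what the removed local_state_dict() used to return.
local_shard = optimizer.optim.state_dict()

# Consolidated view: gather the shards first, then read the full state on the recipient rank.
optimizer.consolidate_state_dict(recipient_rank=0)
full_state = optimizer.state_dict()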
@@ -390,22 +376,14 @@ class OSS(Optimizer):
             )
         }

-        # FIXME: pytorch1.5 compatibility, to be removed when 1.5 support ends
-        _param_list = list(chain.from_iterable((g["params"] for g in self.param_groups)))
-
         for key, value in state_dict["state"].items():
-            if key in id_map:
-                param = id_map[key]
+            param = id_map[key]

-                # Populate the sharded optimizer state on the fly,
-                # remove the params that this rank does not own
-                if self._param_to_rank[param] != self.rank:
-                    state_dict["state"][key] = {}
-                else:
-                    self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
-            else:
-                # Not a param, copied as-is (backward compatibility or exotic optimizers)
-                param = _param_list[key]
-                self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
+            # Populate the sharded optimizer state on the fly,
+            # remove the params that this rank does not own
+            if self._param_to_rank[param] != self.rank:
+                state_dict["state"][key] = {}
+            else:
+                self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)

         super().load_state_dict(state_dict)
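The rewritten loop assumes every key in state_dict["state"] maps back to a parameter through id_map (the PyTorch 1.5 positional-index fallback is gone), and it only materializes optimizer state for parameters owned by the calling rank, emptying the other entries. A hedged sketch of the checkpoint round trip this supports; the file path, model, and single-process group are illustrative assumptions, not part of the commit:

# Illustration only: save the consolidated state on rank 0, then reload it everywhere;
# load_state_dict() keeps only the shard this rank owns, as in the loop above.
import torch
import torch.distributed as dist
from fairscale.optim.oss import OSS

if not dist.is_initialized():
    dist.init_process_group(backend="gloo", init_method="tcp://localhost:29501", rank=0, world_size=1)

model = torch.nn.Linear(8, 8)  # hypothetical model
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

model(torch.randn(2, 8)).sum().backward()
optimizer.step()

optimizer.consolidate_state_dict(recipient_rank=0)
if dist.get_rank() == 0:
    torch.save(optimizer.state_dict(), "oss_checkpoint.pt")  # hypothetical path
dist.barrier()

state = torch.load("oss_checkpoint.pt", map_location="cpu")
optimizer.load_state_dict(state)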