Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
220ee323
Unverified
Commit
220ee323
authored
Aug 27, 2020
by
msbaines
Committed by
GitHub
Aug 27, 2020
Browse files
[fix] optim/oss: PyTorch already handles putting state on proper device (#54)
parent
09028a0d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
6 deletions
+11
-6
fairscale/optim/oss.py
fairscale/optim/oss.py
+1
-4
tests/optim/test_oss.py
tests/optim/test_oss.py
+10
-2
No files found.
fairscale/optim/oss.py
View file @
220ee323
...
...
@@ -139,10 +139,7 @@ class OSS(Optimizer):
def
load_local_state_dict
(
self
,
state_dict
:
dict
)
->
None
:
""" Loads this rank's state_dict. """
# Make sure that the state is on the appropriate device
state_dict_ondevice
=
recursive_copy_to_device
(
state_dict
,
non_blocking
=
False
,
device
=
self
.
_device
)
self
.
optim
.
load_state_dict
(
state_dict_ondevice
)
self
.
optim
.
load_state_dict
(
state_dict
)
def
load_state_dict
(
self
,
state_dict
:
Dict
[
str
,
Any
])
->
None
:
""" Restore the global parameter groups as well as the shard """
...
...
tests/optim/test_oss.py
View file @
220ee323
...
...
@@ -41,7 +41,12 @@ def test_create():
def
test_state_dict
():
x
=
torch
.
tensor
([
1.0
],
device
=
DEVICE
,
requires_grad
=
True
)
o
=
optim
.
OSS
([
x
],
lr
=
0.1
)
o
=
optim
.
OSS
([
x
],
lr
=
0.1
,
momentum
=
0.9
)
x
.
backward
()
o
.
step
()
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
assert
o
.
optim
.
state
[
x
][
"momentum_buffer"
]
==
torch
.
tensor
([
1.0
],
device
=
DEVICE
)
o
.
zero_grad
()
o
.
consolidate_state_dict
()
# Sync state dict in between replicas - even if there are none
state_dict
=
o
.
state_dict
()
...
...
@@ -54,13 +59,16 @@ def test_state_dict():
# Check that it's correctly loaded
o
=
optim
.
OSS
([
x
],
lr
=
0.01
)
o
.
load_state_dict
(
state_dict
)
# Check that state is correct and on proper device
assert
o
.
optim
.
state
[
x
][
"momentum_buffer"
]
==
torch
.
tensor
([
1.0
],
device
=
DEVICE
)
# We should now be using a lr of 0.1, both within the optimizer
# and as exposed by the .param_groups attribute
assert
o
.
param_groups
[
0
][
"lr"
]
==
0.1
x
.
backward
()
o
.
step
()
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
assert
x
==
torch
.
tensor
([
0.71
],
device
=
DEVICE
)
assert
o
.
optim
.
state
[
x
][
"momentum_buffer"
]
==
torch
.
tensor
([
1.9
],
device
=
DEVICE
)
# Check that the exposed param_groups are on the proper device
assert
o
.
param_groups
[
0
][
"params"
][
0
].
device
==
x
.
device
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment