OpenDAS / fairscale / Commits / 220ee323

Unverified commit 220ee323, authored Aug 27, 2020 by msbaines, committed by GitHub on Aug 27, 2020.

[fix] optim/oss: PyTorch already handles putting state on proper device (#54)
Parent: 09028a0d
Showing 2 changed files with 11 additions and 6 deletions:

    fairscale/optim/oss.py    +1  -4
    tests/optim/test_oss.py   +10 -2
fairscale/optim/oss.py

```diff
@@ -139,10 +139,7 @@ class OSS(Optimizer):
     def load_local_state_dict(self, state_dict: dict) -> None:
         """ Loads this rank's state_dict. """
-        # Make sure that the state is on the appropriate device
-        state_dict_ondevice = recursive_copy_to_device(state_dict, non_blocking=False, device=self._device)
-        self.optim.load_state_dict(state_dict_ondevice)
+        self.optim.load_state_dict(state_dict)
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         """ Restore the global parameter groups as well as the shard """
```
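The rationale behind the change: `torch.optim.Optimizer.load_state_dict()` already casts loaded state tensors to the dtype and device of the matching parameters, so copying the incoming state dict onto `self._device` beforehand duplicated work PyTorch performs anyway. Below is a minimal, self-contained sketch of that built-in behavior, using plain `torch.optim.SGD` and made-up variable names rather than anything from this repository:

```python
# Sketch only: shows that Optimizer.load_state_dict() moves state tensors to the
# device of the parameters it is loaded into. Names here are illustrative.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build an optimizer whose momentum buffer lives on the CPU.
x_cpu = torch.ones(1, requires_grad=True)
opt_cpu = torch.optim.SGD([x_cpu], lr=0.1, momentum=0.9)
x_cpu.backward()
opt_cpu.step()
cpu_state = opt_cpu.state_dict()  # momentum_buffer is a CPU tensor here

# Load that CPU state into an optimizer whose parameter lives on `device`.
x_dev = torch.ones(1, device=device, requires_grad=True)
opt_dev = torch.optim.SGD([x_dev], lr=0.1, momentum=0.9)
opt_dev.load_state_dict(cpu_state)

# PyTorch has cast the buffer to match the parameter's device (a GPU if one is
# available; on a CPU-only machine this is a no-op but exercises the same path).
buf = opt_dev.state[x_dev]["momentum_buffer"]
assert buf.device == x_dev.device
```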
tests/optim/test_oss.py

```diff
@@ -41,7 +41,12 @@ def test_create():
 
 def test_state_dict():
     x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
-    o = optim.OSS([x], lr=0.1)
+    o = optim.OSS([x], lr=0.1, momentum=0.9)
+    x.backward()
+    o.step()
+    assert x == torch.tensor([0.9], device=DEVICE)
+    assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.0], device=DEVICE)
+    o.zero_grad()
     o.consolidate_state_dict()  # Sync state dict in between replicas - even if there are none
     state_dict = o.state_dict()
 
@@ -54,13 +59,16 @@ def test_state_dict():
     # Check that it's correctly loaded
     o = optim.OSS([x], lr=0.01)
     o.load_state_dict(state_dict)
+    # Check that state is correct and on proper device
+    assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.0], device=DEVICE)
 
     # We should now be using a lr of 0.1, both within the optimizer
     # and as exposed by the .param_groups attribute
     assert o.param_groups[0]["lr"] == 0.1
     x.backward()
     o.step()
-    assert x == torch.tensor([0.9], device=DEVICE)
+    assert x == torch.tensor([0.71], device=DEVICE)
+    assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.9], device=DEVICE)
 
     # Check that the exposed param_groups are on the proper device
     assert o.param_groups[0]["params"][0].device == x.device
```
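For reference, the updated numbers follow from the standard SGD momentum update with lr=0.1, momentum=0.9, and a gradient of 1.0 on every step: the first step gives buf = 1.0 and x = 1.0 - 0.1 * 1.0 = 0.9; after the state round-trips through state_dict()/load_state_dict(), the second step gives buf = 0.9 * 1.0 + 1.0 = 1.9 and x = 0.9 - 0.1 * 1.9 = 0.71. A small sketch that reproduces the arithmetic with plain `torch.optim.SGD` (not the OSS wrapper, so it only illustrates the numeric behavior the test asserts):

```python
# Sketch only: reproduces the arithmetic asserted in the updated test using
# vanilla SGD with momentum on the CPU.
import torch

x = torch.tensor([1.0], requires_grad=True)
o = torch.optim.SGD([x], lr=0.1, momentum=0.9)

x.backward()  # grad = 1.0
o.step()      # buf = 1.0, x = 1.0 - 0.1 * 1.0 = 0.9
assert torch.allclose(x, torch.tensor([0.9]))
assert torch.allclose(o.state[x]["momentum_buffer"], torch.tensor([1.0]))

o.zero_grad()
x.backward()  # grad = 1.0 again
o.step()      # buf = 0.9 * 1.0 + 1.0 = 1.9, x = 0.9 - 0.1 * 1.9 = 0.71
assert torch.allclose(x, torch.tensor([0.71]))
assert torch.allclose(o.state[x]["momentum_buffer"], torch.tensor([1.9]))
```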