Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
d217278c
Unverified
Commit
d217278c
authored
Mar 16, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Mar 16, 2021
Browse files
[feat][OSS] handle the device being changed after construction (#523)
parent
2d2412e2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
0 deletions
+24
-0
fairscale/optim/oss.py
fairscale/optim/oss.py
+12
-0
tests/optim/test_oss.py
tests/optim/test_oss.py
+12
-0
No files found.
fairscale/optim/oss.py
View file @
d217278c
...
@@ -220,6 +220,12 @@ class OSS(Optimizer):
...
@@ -220,6 +220,12 @@ class OSS(Optimizer):
# Sync oss param_groups attributes in case they've been updated by a scheduler.
# Sync oss param_groups attributes in case they've been updated by a scheduler.
OSS
.
_sync_param_groups
(
self
.
param_groups
,
self
.
optim
.
param_groups
)
OSS
.
_sync_param_groups
(
self
.
param_groups
,
self
.
optim
.
param_groups
)
# Catch a possible change of devices in between OSS construction and step()
if
self
.
_default_device
.
type
!=
self
.
param_groups
[
0
][
"params"
][
0
].
device
.
type
:
logging
.
info
(
"OSS detected that the parameter changed devices, re-allocating buffers"
)
self
.
_clear_cache
()
self
.
refresh_trainable
()
# Run the optimizer step on this shard only:
# Run the optimizer step on this shard only:
if
closure
is
not
None
:
if
closure
is
not
None
:
loss
=
self
.
optim
.
step
(
closure
=
closure
,
**
kwargs
)
# type: ignore
loss
=
self
.
optim
.
step
(
closure
=
closure
,
**
kwargs
)
# type: ignore
...
@@ -591,3 +597,9 @@ class OSS(Optimizer):
...
@@ -591,3 +597,9 @@ class OSS(Optimizer):
else
:
else
:
# This rank has an empty shard, that's fine
# This rank has an empty shard, that's fine
self
.
buckets
[
device
].
append
(
torch
.
zeros
(
0
,
device
=
device
))
self
.
buckets
[
device
].
append
(
torch
.
zeros
(
0
,
device
=
device
))
# Clear the buffer keys which are not in use anymore (could be that the devices changed)
devices_in_use
=
list
(
self
.
per_device_params
.
keys
())
devices_to_pop
=
list
(
filter
(
lambda
x
:
x
not
in
devices_in_use
,
self
.
buckets
.
keys
()))
for
d
in
devices_to_pop
:
self
.
buckets
.
pop
(
d
)
tests/optim/test_oss.py
View file @
d217278c
...
@@ -154,6 +154,18 @@ class TestSingleRank(unittest.TestCase):
...
@@ -154,6 +154,18 @@ class TestSingleRank(unittest.TestCase):
assert
kwarg
==
[
5
]
assert
kwarg
==
[
5
]
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
assert
x
==
torch
.
tensor
([
0.9
],
device
=
DEVICE
)
@skip_if_no_cuda
def test_device_change(self):
    """Regression test: OSS must re-allocate its buffers when the wrapped
    model's parameters are moved to a different device AFTER the OSS
    optimizer has been constructed (see fairscale PR #523).
    """
    # Build the model on CPU first, then wrap its parameters with OSS,
    # so that OSS caches CPU as its default device at construction time.
    x = torch.nn.Linear(1, 1).to("cpu")
    o = optim.OSS(x.parameters(), torch.optim.SGD, lr=0.1)
    # Move the model to device after OSS was constructed
    # (DEVICE is CUDA here — guaranteed by the skip_if_no_cuda decorator).
    x.to(DEVICE)
    # Produce gradients on the new device so that step() has work to do.
    x(torch.zeros((1), device=DEVICE)).backward()
    # Check that OSS detects that the device changed
    # (step() should notice the device mismatch, clear its cached state,
    # and refresh its buffers instead of crashing on stale CPU buffers).
    o.step()
def
test_step_with_extra_inner_key
(
self
):
def
test_step_with_extra_inner_key
(
self
):
class
SGDWithNewKey
(
torch
.
optim
.
SGD
):
class
SGDWithNewKey
(
torch
.
optim
.
SGD
):
# Dummy optimizer which adds a new key to the param groups
# Dummy optimizer which adds a new key to the param groups
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment