Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
e4a0804c
Unverified
Commit
e4a0804c
authored
Aug 27, 2020
by
msbaines
Committed by
GitHub
Aug 27, 2020
Browse files
[refactor] optim/oss: save memory and time by avoiding duplicate copy of parameters (#57)
parent
220ee323
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
13 deletions
+12
-13
fairscale/optim/oss.py
fairscale/optim/oss.py
+9
-7
tests/optim/test_oss.py
tests/optim/test_oss.py
+3
-6
No files found.
fairscale/optim/oss.py
View file @
e4a0804c
...
...
@@ -85,10 +85,9 @@ class OSS(Optimizer):
param_lists
[
rank
].
append
(
param
)
sizes
[
rank
]
+=
param
.
numel
()
for
rank
,
params
in
enumerate
(
param_lists
):
if
len
(
params
)
>
0
:
param_group_rank
=
copy
.
copy
(
param_group
)
param_group_rank
[
"params"
]
=
params
param_groups
[
rank
].
append
(
param_group_rank
)
param_group_rank
=
copy
.
copy
(
param_group
)
param_group_rank
[
"params"
]
=
params
param_groups
[
rank
].
append
(
param_group_rank
)
return
param_groups
# NOTE(msb) We add kwargs in order to support Optimizer sub-classes that accept extra kwargs.
...
...
@@ -134,7 +133,7 @@ class OSS(Optimizer):
len
(
self
.
_all_states
)
>
0
),
"The optimizer state is not materialized, please call consolidate_state_dict on every replica beforehand"
return
{
"state"
:
self
.
_all_states
,
"param_groups"
:
self
.
param_groups
}
return
{
"state"
:
self
.
_all_states
}
def
load_local_state_dict
(
self
,
state_dict
:
dict
)
->
None
:
""" Loads this rank's state_dict. """
...
...
@@ -146,8 +145,11 @@ class OSS(Optimizer):
# Dispatch this rank's state dictionary to the wrapped shard optimizer
self
.
load_local_state_dict
(
state_dict
[
"state"
][
self
.
rank
])
# Restore the global param_groups
self
.
param_groups
=
recursive_copy_to_device
(
state_dict
[
"param_groups"
],
non_blocking
=
True
,
device
=
self
.
_device
)
# Restore the global param_groups (the params themselves are already correct)
for
global_group
,
local_group
in
zip
(
self
.
param_groups
,
self
.
optim
.
param_groups
):
for
k
,
v
in
local_group
.
items
():
if
k
!=
"params"
:
global_group
[
k
]
=
v
def
add_param_group
(
self
,
param_group
:
dict
)
->
None
:
super
().
add_param_group
(
param_group
)
...
...
tests/optim/test_oss.py
View file @
e4a0804c
...
...
@@ -51,10 +51,10 @@ def test_state_dict():
state_dict
=
o
.
state_dict
()
# Check that the pulled state is what we expect
assert
state_dict
[
"param_groups"
][
0
][
"lr"
]
==
0.1
assert
state_dict
[
"
state"
][
0
][
"
param_groups"
][
0
][
"lr"
]
==
0.1
# Check that the pulled state and the .param_groups attribute are in sync
assert
state_dict
[
"param_groups"
][
0
][
"lr"
]
==
o
.
param_groups
[
0
][
"lr"
]
assert
state_dict
[
"
state"
][
0
][
"
param_groups"
][
0
][
"lr"
]
==
o
.
param_groups
[
0
][
"lr"
]
# Check that it's correctly loaded
o
=
optim
.
OSS
([
x
],
lr
=
0.01
)
...
...
@@ -113,10 +113,7 @@ def run_test_add_param_group(rank, world_size):
assert
len
(
o
.
param_groups
)
==
2
# Verify that the new group is added to the correct partition, so that all partitions have 8 elements.
assert
sum
([
x
.
numel
()
for
g
in
o
.
optim
.
param_groups
for
x
in
g
[
"params"
]])
==
8
if
rank
==
1
:
assert
len
(
o
.
optim
.
param_groups
)
==
2
else
:
assert
len
(
o
.
optim
.
param_groups
)
==
1
assert
len
(
o
.
optim
.
param_groups
)
==
2
def
test_add_param_group
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment