Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
8778fa66
Unverified
Commit
8778fa66
authored
Feb 05, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Feb 05, 2021
Browse files
[fix] repro+fix (#365)
Fix a broken earlier commit, which only worked for the first step
parent
4dc605c9
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
8 deletions
+17
-8
fairscale/optim/oss.py
fairscale/optim/oss.py
+14
-8
tests/optim/test_oss.py
tests/optim/test_oss.py
+3
-0
No files found.
fairscale/optim/oss.py
View file @
8778fa66
...
@@ -8,7 +8,7 @@ import copy
...
@@ -8,7 +8,7 @@ import copy
from
itertools
import
chain
from
itertools
import
chain
import
logging
import
logging
from
math
import
inf
from
math
import
inf
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Deque
,
Dict
,
Iterable
,
List
,
Optional
,
Type
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Deque
,
Dict
,
List
,
Optional
,
Type
,
Union
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -81,7 +81,7 @@ class OSS(Optimizer):
...
@@ -81,7 +81,7 @@ class OSS(Optimizer):
self
.
_partition_parameters
:
List
[
List
[
dict
]]
=
[]
self
.
_partition_parameters
:
List
[
List
[
dict
]]
=
[]
self
.
_index_to_param
:
Dict
[
int
,
torch
.
Tensor
]
=
{}
self
.
_index_to_param
:
Dict
[
int
,
torch
.
Tensor
]
=
{}
self
.
_param_to_index
:
Dict
[
int
,
int
]
=
{}
self
.
_param_to_index
:
Dict
[
int
,
int
]
=
{}
self
.
_local_params
:
Optional
[
Iterable
[
Any
]]
=
None
self
.
_local_params
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
# Build the wrapped optimizer, responsible for a shard of the params
# Build the wrapped optimizer, responsible for a shard of the params
self
.
group
=
group
if
group
is
not
None
else
dist
.
group
.
WORLD
self
.
group
=
group
if
group
is
not
None
else
dist
.
group
.
WORLD
...
@@ -145,14 +145,20 @@ class OSS(Optimizer):
...
@@ -145,14 +145,20 @@ class OSS(Optimizer):
return
self
.
_partition_parameters
return
self
.
_partition_parameters
@
property
@
property
def
local_params
(
self
)
->
Iterable
[
torch
.
Tensor
]:
def
local_params
(
self
)
->
List
[
torch
.
Tensor
]:
""" Iterable which goes through the parameters that this rank owns
"""
if
self
.
_local_params
is
None
:
if
self
.
_local_params
is
None
:
self
.
_local_params
=
chain
(
self
.
_local_params
=
list
(
chain
(
*
[
*
[
list
(
filter
(
lambda
x
:
x
.
grad
is
not
None
,
device_params
[
self
.
rank
]))
list
(
filter
(
lambda
x
:
x
.
grad
is
not
None
,
device_params
[
self
.
rank
]))
for
device_params
in
self
.
per_device_params
.
values
()
for
device_params
in
self
.
per_device_params
.
values
()
]
]
)
)
)
# The cached value is a list rather than a one-shot iterator, so it can be iterated repeatedly
return
self
.
_local_params
return
self
.
_local_params
@
property
@
property
...
...
tests/optim/test_oss.py
View file @
8778fa66
...
@@ -632,6 +632,9 @@ def run_gradient_clipping(rank, world_size, tempfile_name):
...
@@ -632,6 +632,9 @@ def run_gradient_clipping(rank, world_size, tempfile_name):
print
(
f
"Checking norm
{
norm
}
"
)
print
(
f
"Checking norm
{
norm
}
"
)
check
(
norm
)
check
(
norm
)
# Check twice, to catch a hypothetical mistake where the iterator is consumed by the first pass
check
(
norm
)
dist
.
destroy_process_group
()
dist
.
destroy_process_group
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment