Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
2eee136f
Unverified
Commit
2eee136f
authored
Oct 01, 2020
by
msbaines
Committed by
GitHub
Oct 01, 2020
Browse files
[fix] re-run black to fix CPU tests on master (#123)
parent
379c6bf0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
5 deletions
+6
-5
fairscale/optim/oss.py
fairscale/optim/oss.py
+6
-5
No files found.
fairscale/optim/oss.py
View file @
2eee136f
...
@@ -137,7 +137,7 @@ class OSS(Optimizer):
...
@@ -137,7 +137,7 @@ class OSS(Optimizer):
@
property
@
property
def
param_to_rank
(
self
)
->
Dict
[
torch
.
Tensor
,
int
]:
def
param_to_rank
(
self
)
->
Dict
[
torch
.
Tensor
,
int
]:
'''
param to data parallel rank
'''
"""
param to data parallel rank
"""
if
len
(
self
.
_param_rank
)
==
0
:
if
len
(
self
.
_param_rank
)
==
0
:
for
rank
,
param_groups
in
enumerate
(
self
.
partition_parameters
()):
for
rank
,
param_groups
in
enumerate
(
self
.
partition_parameters
()):
for
param_group
in
param_groups
:
for
param_group
in
param_groups
:
...
@@ -145,11 +145,11 @@ class OSS(Optimizer):
...
@@ -145,11 +145,11 @@ class OSS(Optimizer):
self
.
_param_rank
[
param
]
=
rank
self
.
_param_rank
[
param
]
=
rank
return
self
.
_param_rank
return
self
.
_param_rank
def
get_global_rank
(
self
,
group
,
rank
)
:
def
get_global_rank
(
self
,
group
:
Any
,
rank
:
int
)
->
int
:
if
group
is
dist
.
group
.
WORLD
:
if
group
is
dist
.
group
.
WORLD
:
return
rank
return
rank
else
:
else
:
global_rank
=
dist
.
distributed_c10d
.
_get_global_rank
(
group
,
rank
)
global_rank
=
dist
.
distributed_c10d
.
_get_global_rank
(
group
,
rank
)
# type: ignore
return
global_rank
return
global_rank
# NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
# NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
...
@@ -376,7 +376,9 @@ class OSS(Optimizer):
...
@@ -376,7 +376,9 @@ class OSS(Optimizer):
logging
.
debug
(
logging
.
debug
(
"Sending the sharded optimizer state to the reference replica from rank %s"
,
rank
,
"Sending the sharded optimizer state to the reference replica from rank %s"
,
rank
,
)
)
broadcast_object
(
self
.
local_state_dict
(),
src_rank
=
self
.
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
)
broadcast_object
(
self
.
local_state_dict
(),
src_rank
=
self
.
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
)
else
:
else
:
global_rank
=
self
.
get_global_rank
(
self
.
group
,
rank
)
global_rank
=
self
.
get_global_rank
(
self
.
group
,
rank
)
# Discard this tensor/rank, broadcast necessary for syncing
# Discard this tensor/rank, broadcast necessary for syncing
...
@@ -393,4 +395,3 @@ class OSS(Optimizer):
...
@@ -393,4 +395,3 @@ class OSS(Optimizer):
for
p
in
partition
:
for
p
in
partition
:
for
t
in
p
[
"params"
]:
for
t
in
p
[
"params"
]:
t
.
grad
=
None
t
.
grad
=
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment