OpenDAS / fairscale · Commits · 2eee136f
"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "83ea9a8d58dc8bd7745663e5161672bfed849da1"
Unverified commit 2eee136f, authored Oct 01, 2020 by msbaines, committed by GitHub on Oct 01, 2020

[fix] re-run black to fix CPU tests on master (#123)

parent 379c6bf0

Showing 1 changed file with 6 additions and 5 deletions:
fairscale/optim/oss.py (+6 -5)
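Note: black is the Python autoformatter that fairscale runs in CI, and the commit title suggests the CPU test job fails when a file is not black-clean. A minimal sketch of reproducing such a check locally, assuming black >= 20.8 (where black.Mode exists) and black's default settings rather than whatever configuration fairscale actually pins:

import black

def is_black_formatted(src: str) -> bool:
    # Formatting is idempotent: a file is clean when reformatting it is a no-op.
    return black.format_str(src, mode=black.Mode()) == src

with open("fairscale/optim/oss.py") as f:
    print(is_black_formatted(f.read()))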
fairscale/optim/oss.py  View file @ 2eee136f
@@ -137,7 +137,7 @@ class OSS(Optimizer):
 
     @property
     def param_to_rank(self) -> Dict[torch.Tensor, int]:
-        '''param to data parallel rank'''
+        """param to data parallel rank"""
         if len(self._param_rank) == 0:
             for rank, param_groups in enumerate(self.partition_parameters()):
                 for param_group in param_groups:
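The first hunk is black's string-quote normalization: single-quoted docstrings are rewritten to double quotes. A quick illustration, assuming black is installed:

import black

src = "def f():\n    '''param to data parallel rank'''\n"
print(black.format_str(src, mode=black.Mode()), end="")
# def f():
#     """param to data parallel rank"""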
@@ -145,11 +145,11 @@ class OSS(Optimizer):
                     self._param_rank[param] = rank
         return self._param_rank
 
-    def get_global_rank(self, group, rank):
+    def get_global_rank(self, group: Any, rank: int) -> int:
         if group is dist.group.WORLD:
             return rank
         else:
-            global_rank = dist.distributed_c10d._get_global_rank(group, rank)
+            global_rank = dist.distributed_c10d._get_global_rank(group, rank)  # type: ignore
         return global_rank
 
     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
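For context, get_global_rank translates a rank measured inside a process group into the global rank across all processes; dist.distributed_c10d._get_global_rank is a private PyTorch API, which is presumably why the # type: ignore became necessary once annotations were added. A minimal single-process sketch of the same translation (the gloo setup is illustrative, and the private helper may be renamed or removed in newer PyTorch releases):

import os
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# Rank 0 *within* this subgroup maps back to global rank 0; for the WORLD
# group the mapping is the identity, hence the early return in the hunk above.
group = dist.new_group(ranks=[0])
print(dist.distributed_c10d._get_global_rank(group, 0))  # -> 0

dist.destroy_process_group()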
@@ -376,7 +376,9 @@ class OSS(Optimizer):
                 logging.debug(
                     "Sending the sharded optimizer state to the reference replica from rank %s", rank,
                 )
-                broadcast_object(self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device)
+                broadcast_object(
+                    self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device
+                )
             else:
                 global_rank = self.get_global_rank(self.group, rank)
                 # Discard this tensor/rank, broadcast necessary for syncing
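broadcast_object here is fairscale's helper for shipping a picklable object, such as a local optimizer state dict, from one rank to the others; the hunk only rewraps the call for line length. A rough sketch of the underlying pattern, with illustrative names rather than the library's exact implementation: pickle into a byte tensor, broadcast the length, then broadcast the payload.

import pickle
import torch
import torch.distributed as dist

def broadcast_object_sketch(obj, src_rank, group, device):
    # Sender: serialize, announce the size, then send the bytes.
    if dist.get_rank() == src_rank:
        payload = torch.tensor(list(pickle.dumps(obj)), dtype=torch.uint8, device=device)
        size = torch.tensor([payload.numel()], dtype=torch.long, device=device)
        dist.broadcast(size, src=src_rank, group=group)
        dist.broadcast(payload, src=src_rank, group=group)
        return obj
    # Receivers: learn the size, allocate, receive, deserialize.
    size = torch.zeros(1, dtype=torch.long, device=device)
    dist.broadcast(size, src=src_rank, group=group)
    payload = torch.empty(int(size.item()), dtype=torch.uint8, device=device)
    dist.broadcast(payload, src=src_rank, group=group)
    return pickle.loads(bytes(payload.tolist()))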
@@ -393,4 +395,3 @@ class OSS(Optimizer):
             for p in partition:
                 for t in p["params"]:
                     t.grad = None
-
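The last hunk touches the code that drops gradients for parameters owned by other ranks. Assigning None, rather than zeroing, releases the gradient tensor's memory outright; a tiny standalone illustration:

import torch

p = torch.nn.Parameter(torch.ones(4))
(p * 2.0).sum().backward()
print(p.grad)   # tensor([2., 2., 2., 2.])

p.grad = None   # frees the allocation; a later backward re-creates it
print(p.grad)   # None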