Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
38ad8638
Unverified
Commit
38ad8638
authored
Jan 27, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Jan 27, 2021
Browse files
[fix] OSS: removing the torch broadcast util altogether, broken on 1.7.1 (#329)
* removing the torch util altogether, broken on 1.7.1
parent
f5ab9a18
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
50 deletions
+23
-50
fairscale/optim/oss.py
fairscale/optim/oss.py
+23
-50
No files found.
fairscale/optim/oss.py
View file @
38ad8638
...
...
@@ -16,7 +16,7 @@ import torch.distributed as dist
from
torch.nn
import
Parameter
from
torch.optim
import
SGD
,
Optimizer
from
.utils
import
Workhandle
,
recursive_copy_to_device
from
.utils
import
Workhandle
,
broadcast_object
,
recursive_copy_to_device
__all__
=
[
"OSS"
]
...
...
@@ -25,15 +25,6 @@ if TYPE_CHECKING: # pragma: no cover
else
:
_params_t
=
Any
try
:
from
torch.distributed
import
broadcast_object_list
# noqa
_torch_broadcast_object
=
True
except
ImportError
:
from
.utils
import
broadcast_object
_torch_broadcast_object
=
False
class
OSS
(
Optimizer
):
"""Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
...
...
@@ -336,27 +327,20 @@ class OSS(Optimizer):
logging
.
debug
(
"Sending the sharded optimizer state to the reference replica from rank %s"
,
rank
,
)
if
_torch_broadcast_object
:
# torch native object broadcast
dist
.
broadcast_object_list
([
local_cpu_state
],
src
=
self
.
global_rank
,
group
=
self
.
group
)
else
:
# legacy compatibility for old torch versions
broadcast_object
(
self
.
local_state_dict
(),
src_rank
=
self
.
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
)
# legacy compatibility for old torch versions
broadcast_object
(
self
.
local_state_dict
(),
src_rank
=
self
.
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
)
else
:
global_rank
=
self
.
get_global_rank
(
self
.
group
,
rank
)
# Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
if
_torch_broadcast_object
:
dist
.
broadcast_object_list
([
dummy_sync_tensor
],
src
=
global_rank
,
group
=
self
.
group
)
else
:
broadcast_object
(
torch
.
tensor
([
dummy_sync_tensor
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
),
src_rank
=
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
,
)
broadcast_object
(
torch
.
tensor
([
dummy_sync_tensor
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
),
src_rank
=
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
,
)
def
_collect_sharded_states
(
self
)
->
List
[
Dict
[
str
,
Any
]]:
"""Collect all the state shards, in CPU memory."""
...
...
@@ -370,32 +354,21 @@ class OSS(Optimizer):
)
# Sync with other replicas
if
_torch_broadcast_object
:
# torch native object broadcast
dist
.
broadcast_object_list
([
0
],
src
=
self
.
global_rank
,
group
=
self
.
group
)
else
:
# legacy compatibility for old torch versions
broadcast_object
(
torch
.
tensor
([
0
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
),
src_rank
=
self
.
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
,
)
broadcast_object
(
torch
.
tensor
([
0
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
),
src_rank
=
self
.
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
,
)
else
:
# Fetch the optim state from the other replicas
global_rank
=
self
.
get_global_rank
(
self
.
group
,
rank
)
if
_torch_broadcast_object
:
replica_state_l
=
[
0
]
dist
.
broadcast_object_list
(
replica_state_l
,
src
=
global_rank
,
group
=
self
.
group
)
replica_state
=
replica_state_l
[
0
]
else
:
replica_state
=
broadcast_object
(
torch
.
tensor
([
0
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
),
src_rank
=
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
,
)
replica_state
=
broadcast_object
(
torch
.
tensor
([
0
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
),
src_rank
=
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
,
)
all_states
.
append
(
recursive_copy_to_device
(
replica_state
,
non_blocking
=
True
,
device
=
torch
.
device
(
"cpu"
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment