Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
38ad8638
Unverified
Commit
38ad8638
authored
Jan 27, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Jan 27, 2021
Browse files
[fix] OSS: removing the torch broadcast util altogether, broken on 1.7.1 (#329)
* removing the torch util altogether, broken on 1.7.1
parent
f5ab9a18
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
50 deletions
+23
-50
fairscale/optim/oss.py
fairscale/optim/oss.py
+23
-50
No files found.
fairscale/optim/oss.py
View file @
38ad8638
...
...
@@ -16,7 +16,7 @@ import torch.distributed as dist
from
torch.nn
import
Parameter
from
torch.optim
import
SGD
,
Optimizer
from
.utils
import
Workhandle
,
recursive_copy_to_device
from
.utils
import
Workhandle
,
broadcast_object
,
recursive_copy_to_device
__all__
=
[
"OSS"
]
...
...
@@ -25,15 +25,6 @@ if TYPE_CHECKING: # pragma: no cover
else
:
_params_t
=
Any
try
:
from
torch.distributed
import
broadcast_object_list
# noqa
_torch_broadcast_object
=
True
except
ImportError
:
from
.utils
import
broadcast_object
_torch_broadcast_object
=
False
class
OSS
(
Optimizer
):
"""Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
...
...
@@ -336,27 +327,20 @@ class OSS(Optimizer):
logging
.
debug
(
"Sending the sharded optimizer state to the reference replica from rank %s"
,
rank
,
)
if
_torch_broadcast_object
:
# torch native object broadcast
dist
.
broadcast_object_list
([
local_cpu_state
],
src
=
self
.
global_rank
,
group
=
self
.
group
)
else
:
# legacy compatibility for old torch versions
broadcast_object
(
self
.
local_state_dict
(),
src_rank
=
self
.
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
)
# legacy compatibility for old torch versions
broadcast_object
(
self
.
local_state_dict
(),
src_rank
=
self
.
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
)
else
:
global_rank
=
self
.
get_global_rank
(
self
.
group
,
rank
)
# Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
if
_torch_broadcast_object
:
dist
.
broadcast_object_list
([
dummy_sync_tensor
],
src
=
global_rank
,
group
=
self
.
group
)
else
:
broadcast_object
(
torch
.
tensor
([
dummy_sync_tensor
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
),
src_rank
=
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
,
)
broadcast_object
(
torch
.
tensor
([
dummy_sync_tensor
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
),
src_rank
=
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
,
)
def
_collect_sharded_states
(
self
)
->
List
[
Dict
[
str
,
Any
]]:
"""Collect all the state shards, in CPU memory."""
...
...
@@ -370,32 +354,21 @@ class OSS(Optimizer):
)
# Sync with other replicas
if
_torch_broadcast_object
:
# torch native object broadcast
dist
.
broadcast_object_list
([
0
],
src
=
self
.
global_rank
,
group
=
self
.
group
)
else
:
# legacy compatibility for old torch versions
broadcast_object
(
torch
.
tensor
([
0
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
),
src_rank
=
self
.
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
,
)
broadcast_object
(
torch
.
tensor
([
0
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
),
src_rank
=
self
.
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
,
)
else
:
# Fetch the optim state from the other replicas
global_rank
=
self
.
get_global_rank
(
self
.
group
,
rank
)
if
_torch_broadcast_object
:
replica_state_l
=
[
0
]
dist
.
broadcast_object_list
(
replica_state_l
,
src
=
global_rank
,
group
=
self
.
group
)
replica_state
=
replica_state_l
[
0
]
else
:
replica_state
=
broadcast_object
(
torch
.
tensor
([
0
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
),
src_rank
=
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
,
)
replica_state
=
broadcast_object
(
torch
.
tensor
([
0
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
),
src_rank
=
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
,
)
all_states
.
append
(
recursive_copy_to_device
(
replica_state
,
non_blocking
=
True
,
device
=
torch
.
device
(
"cpu"
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment