Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
38ad8638
Unverified
Commit
38ad8638
authored
Jan 27, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Jan 27, 2021
Browse files
[fix] OSS: removing the torch broadcast util altogether, broken on 1.7.1 (#329)
* removing the torch util altogether, broken on 1.7.1
parent
f5ab9a18
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
50 deletions
+23
-50
fairscale/optim/oss.py
fairscale/optim/oss.py
+23
-50
No files found.
fairscale/optim/oss.py
View file @
38ad8638
...
@@ -16,7 +16,7 @@ import torch.distributed as dist
...
@@ -16,7 +16,7 @@ import torch.distributed as dist
from
torch.nn
import
Parameter
from
torch.nn
import
Parameter
from
torch.optim
import
SGD
,
Optimizer
from
torch.optim
import
SGD
,
Optimizer
from
.utils
import
Workhandle
,
recursive_copy_to_device
from
.utils
import
Workhandle
,
broadcast_object
,
recursive_copy_to_device
__all__
=
[
"OSS"
]
__all__
=
[
"OSS"
]
...
@@ -25,15 +25,6 @@ if TYPE_CHECKING: # pragma: no cover
...
@@ -25,15 +25,6 @@ if TYPE_CHECKING: # pragma: no cover
else
:
else
:
_params_t
=
Any
_params_t
=
Any
try
:
from
torch.distributed
import
broadcast_object_list
# noqa
_torch_broadcast_object
=
True
except
ImportError
:
from
.utils
import
broadcast_object
_torch_broadcast_object
=
False
class
OSS
(
Optimizer
):
class
OSS
(
Optimizer
):
"""Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
"""Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
...
@@ -336,10 +327,6 @@ class OSS(Optimizer):
...
@@ -336,10 +327,6 @@ class OSS(Optimizer):
logging
.
debug
(
logging
.
debug
(
"Sending the sharded optimizer state to the reference replica from rank %s"
,
rank
,
"Sending the sharded optimizer state to the reference replica from rank %s"
,
rank
,
)
)
if
_torch_broadcast_object
:
# torch native object broadcast
dist
.
broadcast_object_list
([
local_cpu_state
],
src
=
self
.
global_rank
,
group
=
self
.
group
)
else
:
# legacy compatibility for old torch versions
# legacy compatibility for old torch versions
broadcast_object
(
broadcast_object
(
self
.
local_state_dict
(),
src_rank
=
self
.
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
self
.
local_state_dict
(),
src_rank
=
self
.
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
...
@@ -348,9 +335,6 @@ class OSS(Optimizer):
...
@@ -348,9 +335,6 @@ class OSS(Optimizer):
global_rank
=
self
.
get_global_rank
(
self
.
group
,
rank
)
global_rank
=
self
.
get_global_rank
(
self
.
group
,
rank
)
# Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
# Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
if
_torch_broadcast_object
:
dist
.
broadcast_object_list
([
dummy_sync_tensor
],
src
=
global_rank
,
group
=
self
.
group
)
else
:
broadcast_object
(
broadcast_object
(
torch
.
tensor
([
dummy_sync_tensor
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
),
torch
.
tensor
([
dummy_sync_tensor
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
),
src_rank
=
global_rank
,
src_rank
=
global_rank
,
...
@@ -370,11 +354,6 @@ class OSS(Optimizer):
...
@@ -370,11 +354,6 @@ class OSS(Optimizer):
)
)
# Sync with other replicas
# Sync with other replicas
if
_torch_broadcast_object
:
# torch native object broadcast
dist
.
broadcast_object_list
([
0
],
src
=
self
.
global_rank
,
group
=
self
.
group
)
else
:
# legacy compatibility for old torch versions
broadcast_object
(
broadcast_object
(
torch
.
tensor
([
0
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
),
torch
.
tensor
([
0
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
),
src_rank
=
self
.
global_rank
,
src_rank
=
self
.
global_rank
,
...
@@ -384,12 +363,6 @@ class OSS(Optimizer):
...
@@ -384,12 +363,6 @@ class OSS(Optimizer):
else
:
else
:
# Fetch the optim state from the other replicas
# Fetch the optim state from the other replicas
global_rank
=
self
.
get_global_rank
(
self
.
group
,
rank
)
global_rank
=
self
.
get_global_rank
(
self
.
group
,
rank
)
if
_torch_broadcast_object
:
replica_state_l
=
[
0
]
dist
.
broadcast_object_list
(
replica_state_l
,
src
=
global_rank
,
group
=
self
.
group
)
replica_state
=
replica_state_l
[
0
]
else
:
replica_state
=
broadcast_object
(
replica_state
=
broadcast_object
(
torch
.
tensor
([
0
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
),
torch
.
tensor
([
0
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
),
src_rank
=
global_rank
,
src_rank
=
global_rank
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment