Unverified Commit 38ad8638 authored by Benjamin Lefaudeux, committed by GitHub

[fix] OSS: removing the torch broadcast util altogether, broken on 1.7.1 (#329)

* removing the torch util altogether, broken on 1.7.1
parent f5ab9a18
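
For context: the in-repo `.utils.broadcast_object` helper, which this commit now uses unconditionally, broadcasts an arbitrary picklable object using only plain tensor collectives, which is why it keeps working on torch versions where `torch.distributed.broadcast_object_list` is unavailable or misbehaves. The sketch below illustrates that general technique only; it is a hedged, assumption-laden illustration, not fairscale's actual implementation (the name `broadcast_object_sketch` and its defaults are made up here).

import io
import pickle
from typing import Any, Optional

import torch
import torch.distributed as dist


def broadcast_object_sketch(
    obj: Any,
    src_rank: int,
    group: Optional[Any] = None,          # None -> default process group
    device: torch.device = torch.device("cpu"),
) -> Any:
    """Minimal sketch: ship a picklable object from src_rank to every rank via two broadcasts."""
    if dist.get_rank() == src_rank:
        # Source rank: serialize, then broadcast the payload length followed by the payload
        # so receivers can allocate a buffer of the right size.
        buffer = io.BytesIO()
        pickle.dump(obj, buffer)
        payload = torch.tensor(bytearray(buffer.getvalue()), dtype=torch.uint8, device=device)
        length = torch.tensor([payload.numel()], dtype=torch.long, device=device)
        dist.broadcast(length, src=src_rank, group=group)
        dist.broadcast(payload, src=src_rank, group=group)
        return obj
    # Receiving ranks: get the length, allocate the buffer, receive, then unpickle.
    length = torch.tensor([0], dtype=torch.long, device=device)
    dist.broadcast(length, src=src_rank, group=group)
    payload = torch.empty(int(length.item()), dtype=torch.uint8, device=device)
    dist.broadcast(payload, src=src_rank, group=group)
    return pickle.loads(payload.cpu().numpy().tobytes())

The `device` argument here mirrors the `dist_device=self._device` parameter visible in the diff below: with the NCCL backend the staging tensors must live on the GPU.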
@@ -16,7 +16,7 @@ import torch.distributed as dist
 from torch.nn import Parameter
 from torch.optim import SGD, Optimizer
-from .utils import Workhandle, recursive_copy_to_device
+from .utils import Workhandle, broadcast_object, recursive_copy_to_device
 __all__ = ["OSS"]
@@ -25,15 +25,6 @@ if TYPE_CHECKING: # pragma: no cover
 else:
     _params_t = Any
-try:
-    from torch.distributed import broadcast_object_list  # noqa
-    _torch_broadcast_object = True
-except ImportError:
-    from .utils import broadcast_object
-    _torch_broadcast_object = False
 class OSS(Optimizer):
     """Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
@@ -336,10 +327,6 @@ class OSS(Optimizer):
             logging.debug(
                 "Sending the sharded optimizer state to the reference replica from rank %s", rank,
             )
-            if _torch_broadcast_object:
-                # torch native object broadcast
-                dist.broadcast_object_list([local_cpu_state], src=self.global_rank, group=self.group)
-            else:
             # legacy compatibility for old torch versions
             broadcast_object(
                 self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device
@@ -348,9 +335,6 @@ class OSS(Optimizer):
             global_rank = self.get_global_rank(self.group, rank)
             # Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
-            if _torch_broadcast_object:
-                dist.broadcast_object_list([dummy_sync_tensor], src=global_rank, group=self.group)
-            else:
             broadcast_object(
                 torch.tensor([dummy_sync_tensor], dtype=torch.uint8, device=self._device),
                 src_rank=global_rank,
@@ -370,11 +354,6 @@ class OSS(Optimizer):
             )
             # Sync with other replicas
-            if _torch_broadcast_object:
-                # torch native object broadcast
-                dist.broadcast_object_list([0], src=self.global_rank, group=self.group)
-            else:
-                # legacy compatibility for old torch versions
             broadcast_object(
                 torch.tensor([0], dtype=torch.uint8, device=self._device),
                 src_rank=self.global_rank,
@@ -384,12 +363,6 @@ class OSS(Optimizer):
         else:
             # Fetch the optim state from the other replicas
             global_rank = self.get_global_rank(self.group, rank)
-            if _torch_broadcast_object:
-                replica_state_l = [0]
-                dist.broadcast_object_list(replica_state_l, src=global_rank, group=self.group)
-                replica_state = replica_state_l[0]
-            else:
             replica_state = broadcast_object(
                 torch.tensor([0], dtype=torch.uint8, device=self._device),
                 src_rank=global_rank,
...
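
For reference, the removed branches relied on `torch.distributed.broadcast_object_list`, which has in-place list semantics: every rank passes a list of the same length, and the entries on non-source ranks are overwritten with the broadcast objects, which is why the removed code read the result back out of `replica_state_l[0]`. A small, hedged usage sketch of that API (process-group setup omitted, function name and `state` argument are illustrative):

import torch.distributed as dist


def share_state_torch_native(state: dict, src_rank: int) -> dict:
    # Every rank supplies a same-length list; on non-source ranks the entries are
    # placeholders that broadcast_object_list overwrites in place.
    object_list = [state if dist.get_rank() == src_rank else None]
    dist.broadcast_object_list(object_list, src=src_rank)
    return object_list[0]

The in-repo helper kept by this commit returns the received object directly instead, which is the return-value style seen in the last hunk (`replica_state = broadcast_object(...)`).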