Unverified commit 38ad8638, authored by Benjamin Lefaudeux, committed by GitHub

[fix] OSS: removing the torch broadcast util altogether, broken on 1.7.1 (#329)

* removing the torch util altogether, broken on 1.7.1
parent f5ab9a18
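
Side note on the helper this commit keeps: `.utils.broadcast_object` follows the pickle-over-tensor pattern that was commonly used before `torch.distributed.broadcast_object_list` behaved consistently across versions. The sketch below only illustrates that pattern under the call signature seen in the diff (`src_rank`, `group`, `dist_device`); it is not the exact fairscale implementation.

```python
# Illustrative sketch only -- not the exact fairscale implementation.
# Broadcasts an arbitrary picklable object by serializing it into a uint8
# tensor, sending its length first, then the payload.
import io

import torch
import torch.distributed as dist


def broadcast_object(obj, src_rank, group=dist.group.WORLD, dist_device=torch.device("cpu")):
    """Broadcast a picklable object from src_rank to every rank in the group."""
    if dist.get_rank() == src_rank:
        # Source rank: serialize, then broadcast length followed by the payload.
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        payload = torch.ByteTensor(bytearray(buffer.getbuffer())).to(dist_device)
        length = torch.LongTensor([payload.numel()]).to(dist_device)
        dist.broadcast(length, src=src_rank, group=group)
        dist.broadcast(payload, src=src_rank, group=group)
        return obj

    # Other ranks: receive the length, allocate a buffer, receive, deserialize.
    length = torch.LongTensor([0]).to(dist_device)
    dist.broadcast(length, src=src_rank, group=group)
    payload = torch.empty(int(length.item()), dtype=torch.uint8, device=dist_device)
    dist.broadcast(payload, src=src_rank, group=group)
    return torch.load(io.BytesIO(payload.cpu().numpy().tobytes()))
```

Broadcasting the payload length first lets the non-source ranks size their receive buffer. The `dist_device` argument matters because the NCCL backend only handles CUDA tensors, which is why the call sites in the diff below pass `dist_device=self._device`.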
@@ -16,7 +16,7 @@ import torch.distributed as dist
 from torch.nn import Parameter
 from torch.optim import SGD, Optimizer
 
-from .utils import Workhandle, recursive_copy_to_device
+from .utils import Workhandle, broadcast_object, recursive_copy_to_device
 
 __all__ = ["OSS"]
 
@@ -25,15 +25,6 @@ if TYPE_CHECKING:  # pragma: no cover
 else:
     _params_t = Any
 
-try:
-    from torch.distributed import broadcast_object_list  # noqa
-
-    _torch_broadcast_object = True
-except ImportError:
-    from .utils import broadcast_object
-
-    _torch_broadcast_object = False
-
 
 class OSS(Optimizer):
     """Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
@@ -336,27 +327,20 @@ class OSS(Optimizer):
                 logging.debug(
                     "Sending the sharded optimizer state to the reference replica from rank %s", rank,
                 )
-                if _torch_broadcast_object:
-                    # torch native object broadcast
-                    dist.broadcast_object_list([local_cpu_state], src=self.global_rank, group=self.group)
-                else:
-                    # legacy compatibility for old torch versions
-                    broadcast_object(
-                        self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device
-                    )
+                # legacy compatibility for old torch versions
+                broadcast_object(
+                    self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device
+                )
             else:
                 global_rank = self.get_global_rank(self.group, rank)
 
                 # Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
-                if _torch_broadcast_object:
-                    dist.broadcast_object_list([dummy_sync_tensor], src=global_rank, group=self.group)
-                else:
-                    broadcast_object(
-                        torch.tensor([dummy_sync_tensor], dtype=torch.uint8, device=self._device),
-                        src_rank=global_rank,
-                        group=self.group,
-                        dist_device=self._device,
-                    )
+                broadcast_object(
+                    torch.tensor([dummy_sync_tensor], dtype=torch.uint8, device=self._device),
+                    src_rank=global_rank,
+                    group=self.group,
+                    dist_device=self._device,
+                )
 
     def _collect_sharded_states(self) -> List[Dict[str, Any]]:
         """Collect all the state shards, in CPU memory."""
@@ -370,32 +354,21 @@ class OSS(Optimizer):
                 )
 
                 # Sync with other replicas
-                if _torch_broadcast_object:
-                    # torch native object broadcast
-                    dist.broadcast_object_list([0], src=self.global_rank, group=self.group)
-                else:
-                    # legacy compatibility for old torch versions
-                    broadcast_object(
-                        torch.tensor([0], dtype=torch.uint8, device=self._device),
-                        src_rank=self.global_rank,
-                        group=self.group,
-                        dist_device=self._device,
-                    )
+                broadcast_object(
+                    torch.tensor([0], dtype=torch.uint8, device=self._device),
+                    src_rank=self.global_rank,
+                    group=self.group,
+                    dist_device=self._device,
+                )
             else:
                 # Fetch the optim state from the other replicas
                 global_rank = self.get_global_rank(self.group, rank)
-
-                if _torch_broadcast_object:
-                    replica_state_l = [0]
-                    dist.broadcast_object_list(replica_state_l, src=global_rank, group=self.group)
-                    replica_state = replica_state_l[0]
-                else:
-                    replica_state = broadcast_object(
-                        torch.tensor([0], dtype=torch.uint8, device=self._device),
-                        src_rank=global_rank,
-                        group=self.group,
-                        dist_device=self._device,
-                    )
+                replica_state = broadcast_object(
+                    torch.tensor([0], dtype=torch.uint8, device=self._device),
+                    src_rank=global_rank,
+                    group=self.group,
+                    dist_device=self._device,
+                )
 
                 all_states.append(
                     recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
...
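
For completeness, a hedged sketch of how this broadcast path is typically reached from user code. `OSS(...)`, `consolidate_state_dict()` and `state_dict()` are assumed entry points of `fairscale.optim.OSS`; they are not part of the diff above.

```python
# Usage sketch (assumed API, not shown in this diff): every rank participates
# in consolidation, and only the recipient rank persists the merged state.
import torch
import torch.distributed as dist
from fairscale.optim import OSS


def save_optimizer(model: torch.nn.Module, path: str) -> None:
    optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

    # ... run training steps with optimizer.step() ...

    # Each shard owner broadcasts its local state dict (the broadcast_object
    # call in the diff); the other ranks answer with a dummy tensor so the
    # collective stays symmetric, since the code above notes that NCCL does
    # not support gather.
    optimizer.consolidate_state_dict(recipient_rank=0)

    if dist.get_rank() == 0:
        torch.save(optimizer.state_dict(), path)
```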