Unverified Commit 38ad8638 authored by Benjamin Lefaudeux, committed by GitHub

[fix] OSS: removing the torch broadcast util altogether, broken on 1.7.1 (#329)

* removing the torch util altogether, broken on 1.7.1
parent f5ab9a18
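
Note on the fallback path: torch.distributed.broadcast_object_list proved unreliable on torch 1.7.1, so the optimizer now always goes through fairscale's own broadcast_object helper from .utils. That helper follows the usual pickle-then-broadcast pattern. A minimal, illustrative sketch of that pattern follows; the function name, the two-step length/payload protocol, and the use of torch.save/torch.load are assumptions for illustration, not fairscale's exact implementation.

import io
from typing import Any

import torch
import torch.distributed as dist


def broadcast_object_sketch(obj: Any, src_rank: int, group: Any, dist_device: torch.device) -> Any:
    # Illustrative sketch only (hypothetical helper): serialize the object to a uint8
    # tensor, broadcast its length, then broadcast the payload, and deserialize on the
    # receiving ranks. Requires an initialized process group; src_rank is a global rank.
    if dist.get_rank() == src_rank:
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        payload = torch.ByteTensor(bytearray(buffer.getbuffer())).to(dist_device)
        length = torch.tensor([payload.numel()], dtype=torch.long, device=dist_device)
        dist.broadcast(length, src=src_rank, group=group)
        dist.broadcast(payload, src=src_rank, group=group)
        return obj
    else:
        length = torch.tensor([0], dtype=torch.long, device=dist_device)
        dist.broadcast(length, src=src_rank, group=group)
        payload = torch.empty(int(length.item()), dtype=torch.uint8, device=dist_device)
        dist.broadcast(payload, src=src_rank, group=group)
        return torch.load(io.BytesIO(payload.cpu().numpy().tobytes()))
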
@@ -16,7 +16,7 @@ import torch.distributed as dist
 from torch.nn import Parameter
 from torch.optim import SGD, Optimizer
-from .utils import Workhandle, recursive_copy_to_device
+from .utils import Workhandle, broadcast_object, recursive_copy_to_device
 __all__ = ["OSS"]
@@ -25,15 +25,6 @@ if TYPE_CHECKING: # pragma: no cover
 else:
     _params_t = Any
-try:
-    from torch.distributed import broadcast_object_list  # noqa
-    _torch_broadcast_object = True
-except ImportError:
-    from .utils import broadcast_object
-    _torch_broadcast_object = False
 class OSS(Optimizer):
     """Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
@@ -336,27 +327,20 @@ class OSS(Optimizer):
                 logging.debug(
                     "Sending the sharded optimizer state to the reference replica from rank %s", rank,
                 )
-                if _torch_broadcast_object:
-                    # torch native object broadcast
-                    dist.broadcast_object_list([local_cpu_state], src=self.global_rank, group=self.group)
-                else:
-                    # legacy compatibility for old torch versions
-                    broadcast_object(
-                        self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device
-                    )
+                # legacy compatibility for old torch versions
+                broadcast_object(
+                    self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device
+                )
             else:
                 global_rank = self.get_global_rank(self.group, rank)
                 # Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
-                if _torch_broadcast_object:
-                    dist.broadcast_object_list([dummy_sync_tensor], src=global_rank, group=self.group)
-                else:
-                    broadcast_object(
-                        torch.tensor([dummy_sync_tensor], dtype=torch.uint8, device=self._device),
-                        src_rank=global_rank,
-                        group=self.group,
-                        dist_device=self._device,
-                    )
+                broadcast_object(
+                    torch.tensor([dummy_sync_tensor], dtype=torch.uint8, device=self._device),
+                    src_rank=global_rank,
+                    group=self.group,
+                    dist_device=self._device,
+                )
     def _collect_sharded_states(self) -> List[Dict[str, Any]]:
         """Collect all the state shards, in CPU memory."""
@@ -370,32 +354,21 @@
                 )
                 # Sync with other replicas
-                if _torch_broadcast_object:
-                    # torch native object broadcast
-                    dist.broadcast_object_list([0], src=self.global_rank, group=self.group)
-                else:
-                    # legacy compatibility for old torch versions
-                    broadcast_object(
-                        torch.tensor([0], dtype=torch.uint8, device=self._device),
-                        src_rank=self.global_rank,
-                        group=self.group,
-                        dist_device=self._device,
-                    )
+                broadcast_object(
+                    torch.tensor([0], dtype=torch.uint8, device=self._device),
+                    src_rank=self.global_rank,
+                    group=self.group,
+                    dist_device=self._device,
+                )
             else:
                 # Fetch the optim state from the other replicas
                 global_rank = self.get_global_rank(self.group, rank)
-                if _torch_broadcast_object:
-                    replica_state_l = [0]
-                    dist.broadcast_object_list(replica_state_l, src=global_rank, group=self.group)
-                    replica_state = replica_state_l[0]
-                else:
-                    replica_state = broadcast_object(
-                        torch.tensor([0], dtype=torch.uint8, device=self._device),
-                        src_rank=global_rank,
-                        group=self.group,
-                        dist_device=self._device,
-                    )
+                replica_state = broadcast_object(
+                    torch.tensor([0], dtype=torch.uint8, device=self._device),
+                    src_rank=global_rank,
+                    group=self.group,
+                    dist_device=self._device,
+                )
                 all_states.append(
                     recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
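
For reference, the call pattern that consolidate_state_dict() and state_dict() support looks roughly like the sketch below; the fairscale.optim import path, the optim= keyword for the wrapped optimizer class, and the default reference rank of 0 reflect the fairscale API of this era and should be checked against the installed version.

import torch
import torch.distributed as dist
from fairscale.optim import OSS

# Assumes torch.distributed is already initialized (e.g. via torch.distributed.launch).
model = torch.nn.Linear(8, 8).cuda()
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

# ... the usual forward / backward / optimizer.step() loop ...

# consolidate_state_dict() is a collective call: every rank must enter it.
optimizer.consolidate_state_dict()

# Only the reference replica (rank 0 by default) ends up holding the full, unsharded state.
if dist.get_rank() == 0:
    torch.save(optimizer.state_dict(), "optim_full_state.pt")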