Unverified Commit 3399e97c authored by Benjamin Lefaudeux, committed by GitHub

[refactor][OSS] Removing ad-hoc object broadcast, use pytorch's (#297)

parent 9faad392
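
For context, the change replaces fairscale's hand-rolled broadcast_object helper (removed further down) with torch.distributed.broadcast_object_list, which pickles arbitrary Python objects and broadcasts them from the source rank, overwriting the entries of the passed list in place on every other rank. A minimal, self-contained sketch of that call pattern; the two-process gloo setup, the MASTER_ADDR/MASTER_PORT values and the demo payload are illustrative assumptions, not part of this commit:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def demo(rank: int, world_size: int) -> None:
    # Gloo keeps the sketch runnable on CPU-only machines; NCCL behaves the same way,
    # apart from requiring the payload tensors to live on the right GPU.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    if rank == 0:
        payload = [{"step": 10, "exp_avg": torch.zeros(4)}]  # source rank fills the slot
    else:
        payload = [None]  # placeholder, overwritten in place by the broadcast

    dist.broadcast_object_list(payload, src=0)
    print(f"rank {rank} received: {payload[0]}")

    dist.destroy_process_group()


if __name__ == "__main__":
    # Illustrative rendezvous settings; in fairscale's tests this is handled by the test harness.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    mp.spawn(demo, args=(2,), nprocs=2)
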
......@@ -16,7 +16,7 @@ import torch.distributed as dist
from torch.nn import Parameter
from torch.optim import SGD, Optimizer
from .utils import Bucket, Workhandle, broadcast_object, recursive_copy_to_device
from .utils import Bucket, Workhandle, recursive_copy_to_device
__all__ = ["OSS"]
......@@ -320,6 +320,55 @@ class OSS(Optimizer):
            # Acknowledge broadcasts, and send this rank's shard when needed
            self._broadcast_state_dict()

    def _broadcast_state_dict(self) -> None:
        """Broadcast this rank's state shard, discard others"""

        # Default to CPU space to gain some memory headroom
        local_cpu_state = recursive_copy_to_device(
            self.local_state_dict(), non_blocking=True, device=torch.device("cpu")
        )

        for rank in range(self.world_size):
            if rank == self.rank:
                # Send the state to the reference replica
                logging.debug(
                    "Sending the sharded optimizer state to the reference replica from rank %s", rank,
                )
                dist.broadcast_object_list([local_cpu_state], src=self.global_rank, group=self.group)
            else:
                global_rank = self.get_global_rank(self.group, rank)
                # Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
                dist.broadcast_object_list([0], src=global_rank, group=self.group)

    def _collect_sharded_states(self) -> List[Dict[str, Any]]:
        """Collect all the state shards, in CPU memory."""
        all_states = []
        for rank in range(self.world_size):
            if rank == self.rank:
                logging.debug("Saving self state")
                all_states.append(
                    recursive_copy_to_device(self.local_state_dict(), non_blocking=True, device=torch.device("cpu"))
                )

                # Sync with other replicas
                dist.broadcast_object_list([0], src=self.global_rank, group=self.group)
            else:
                # Fetch the optim state from the other replicas
                global_rank = self.get_global_rank(self.group, rank)
                replica_state = [0]
                dist.broadcast_object_list(replica_state, src=global_rank, group=self.group)
                all_states.append(
                    recursive_copy_to_device(replica_state[0], non_blocking=True, device=torch.device("cpu"))
                )
                logging.debug("State from rank %s received", rank)

        return all_states

    def state_dict(self) -> Dict[str, Any]:
        """Return the last known global optimizer state, which consist of a list of the shards.
......@@ -466,53 +515,6 @@ class OSS(Optimizer):
elif k in global_group.keys():
local_group[k] = global_group[k]
    def _collect_sharded_states(self) -> List[Dict[str, Any]]:
        """Collect all the state shards, in CPU memory."""
        empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
        all_states: List[Dict[str, Any]] = []

        for rank in range(self.world_size):
            if rank == self.rank:
                logging.debug("Saving self state")
                all_states.append(
                    recursive_copy_to_device(self.local_state_dict(), non_blocking=True, device=torch.device("cpu"))
                )

                # Sync with other replicas
                broadcast_object(empty_buffer, src_rank=self.global_rank, group=self.group, dist_device=self._device)
            else:
                # Fetch the optim state from the other replicas
                global_rank = self.get_global_rank(self.group, rank)
                replica_state = broadcast_object(
                    empty_buffer, src_rank=global_rank, group=self.group, dist_device=self._device
                )
                all_states.append(
                    recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
                )
                logging.debug("State from rank %s received", rank)

        return all_states

    def _broadcast_state_dict(self) -> None:
        """Broadcast this rank's state shard, discard others"""
        empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)

        for rank in range(self.world_size):
            if rank == self.rank:
                # Send the state to the reference replica
                logging.debug(
                    "Sending the sharded optimizer state to the reference replica from rank %s", rank,
                )
                broadcast_object(
                    self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device
                )
            else:
                global_rank = self.get_global_rank(self.group, rank)
                # Discard this tensor/rank, broadcast necessary for syncing
                broadcast_object(empty_buffer, src_rank=global_rank, group=self.group, dist_device=self._device)

    def _broadcast_params(self) -> None:
        """Helper function to broadcast all the parameters from a given device"""
......
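
Taken together, the two new methods above implement a small consolidation protocol: every rank takes one turn as the broadcast source, the recipient rank keeps each payload, and everyone else sends or receives dummies purely to stay in lockstep (NCCL has no object gather, hence the serial broadcasts). The following standalone sketch folds both sides into one symmetric function; consolidate, local_shard and recipient_rank are illustrative names, not fairscale API, and an initialized process group is assumed:

from typing import Any, List

import torch.distributed as dist


def consolidate(local_shard: Any, rank: int, world_size: int, recipient_rank: int = 0) -> List[Any]:
    """Return every rank's shard on recipient_rank, an empty list elsewhere (sketch only)."""
    collected: List[Any] = []
    for src in range(world_size):
        if src == rank:
            # Our turn to emit. The recipient already holds its own shard, so it only
            # broadcasts a dummy token to keep every rank's broadcast schedule aligned.
            payload = [0] if rank == recipient_rank else [local_shard]
            dist.broadcast_object_list(payload, src=src)
            if rank == recipient_rank:
                collected.append(local_shard)
        else:
            # Someone else's turn: take part in the broadcast, then keep or drop the
            # received object depending on whether we are the recipient.
            box = [0]
            dist.broadcast_object_list(box, src=src)
            if rank == recipient_rank:
                collected.append(box[0])
    return collected
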
......@@ -3,12 +3,10 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import io
from typing import Any, Callable, Dict, List, Optional
import torch
from torch._six import container_abcs
import torch.distributed as dist
class Workhandle:
......@@ -128,31 +126,3 @@ def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.devic
        return device_val

    return value
def broadcast_object(
    obj: Any, src_rank: int, group: object = dist.group.WORLD, dist_device: torch.device = torch.device("cpu")
) -> Any:
    """
    Either broadcast from master to the fleet (default),
    or use the src setting as the original rank.
    """

    if dist.get_rank() == src_rank:
        # Emit data
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data = bytearray(buffer.getbuffer())
        length_tensor = torch.LongTensor([len(data)]).to(dist_device)
        data_send_tensor = torch.ByteTensor(data).to(dist_device)
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
    else:
        # Fetch from the source
        length_tensor = torch.LongTensor([0]).to(dist_device)
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=dist_device)
        dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
        buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
        obj = torch.load(buffer, map_location=dist_device)
    return obj
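
The removed helper above serializes the object with torch.save, then broadcasts a length tensor followed by the raw byte tensor; torch.distributed.broadcast_object_list follows broadly the same size-then-payload scheme internally, so the switch mostly trades bespoke code for the maintained built-in. For callers, the removed return-value style maps onto the in-place list mutation roughly as sketched here; broadcast_any is a hypothetical name, not part of the commit:

from typing import Any

import torch.distributed as dist


def broadcast_any(obj: Any, src_rank: int, group: Any = None) -> Any:
    """Rough functional stand-in for the removed broadcast_object, built on the torch API."""
    # Wrap in a single-element list: the source rank's entry is sent as-is, and the
    # entry is overwritten in place on every receiving rank.
    box = [obj if dist.get_rank() == src_rank else None]
    dist.broadcast_object_list(box, src=src_rank, group=group)
    return box[0]
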
......@@ -32,6 +32,7 @@ def get_backend(group: Optional[Any] = None) -> Any: ...
def broadcast(tensor: Tensor, src: Any, group: Any, async_op: Any = False): ...
def gather(tensor: Tensor, gather_list: Optional[List[Tensor]], dst: Any, group:Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
def reduce(tensor: Tensor, dst: Any, op: Optional[Any]=ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
def broadcast_object_list(object_list: List[Any], src: int, group:Optional[ProcessGroup] = None): ...
def is_initialized() -> bool: ...
......
......@@ -392,12 +392,11 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
    else:
        optimizer_state_dict = {}

    optimizer_state_dict = optim.utils.broadcast_object(
        optimizer_state_dict, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device
    )
    optim_state = [optimizer_state_dict]
    dist.broadcast_object_list(optim_state, src=reference_rank, group=dist.group.WORLD)

    # Load the optimizer state dict
    optimizer.load_state_dict(optimizer_state_dict)
    optimizer.load_state_dict(optim_state[0])

    dist.destroy_process_group()
......
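
The test change shows the end-to-end restore pattern under the new API: the reference rank exposes the consolidated state dict, it is broadcast to all ranks through a single-element list, and every rank then loads it. Condensed into one hypothetical helper (restore_everywhere is not fairscale API, and the optimizer state is assumed to have been consolidated onto reference_rank beforehand, as in the oss.py hunks above):

import torch.distributed as dist
from torch.optim import Optimizer


def restore_everywhere(optimizer: Optimizer, reference_rank: int) -> None:
    # Only the reference rank holds the full, consolidated state dict; the others
    # contribute an empty placeholder so the broadcast list has the same shape everywhere.
    state = optimizer.state_dict() if dist.get_rank() == reference_rank else {}
    box = [state]
    dist.broadcast_object_list(box, src=reference_rank, group=dist.group.WORLD)
    optimizer.load_state_dict(box[0])
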