Unverified Commit 3399e97c authored by Benjamin Lefaudeux, committed by GitHub

[refactor][OSS] Removing ad-hoc object broadcast, use pytorch's (#297)

parent 9faad392
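In short: the hand-rolled `broadcast_object` helper imported from `.utils` is dropped in favor of PyTorch's built-in `torch.distributed.broadcast_object_list` collective, both in the OSS optimizer and in the tests. For reference, a minimal sketch of the replacement pattern (hypothetical `sync_state` helper; assumes an already-initialized default process group and a picklable `state`):

```python
import torch.distributed as dist

def sync_state(state, src_rank=0):
    """Broadcast a picklable Python object from src_rank to every rank.

    broadcast_object_list mutates the list in place on the receiving ranks,
    so each rank reads the result back out of the one-element list.
    """
    buffer = [state if dist.get_rank() == src_rank else None]
    dist.broadcast_object_list(buffer, src=src_rank)
    return buffer[0]
```

The removed helper did the same thing by hand: serialize with `torch.save`, broadcast a length tensor, then broadcast the payload bytes (see the deleted `broadcast_object` below).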
@@ -16,7 +16,7 @@ import torch.distributed as dist
 from torch.nn import Parameter
 from torch.optim import SGD, Optimizer
-from .utils import Bucket, Workhandle, broadcast_object, recursive_copy_to_device
+from .utils import Bucket, Workhandle, recursive_copy_to_device

 __all__ = ["OSS"]

@@ -320,6 +320,55 @@ class OSS(Optimizer):
             # Acknowledge broadcasts, and send this rank's shard when needed
             self._broadcast_state_dict()
+    def _broadcast_state_dict(self) -> None:
+        """Broadcast this rank's state shard, discard others"""
+        # Default to CPU space to gain some memory headroom
+        local_cpu_state = recursive_copy_to_device(
+            self.local_state_dict(), non_blocking=True, device=torch.device("cpu")
+        )
+
+        for rank in range(self.world_size):
+            if rank == self.rank:
+                # Send the state to the reference replica
+                logging.debug(
+                    "Sending the sharded optimizer state to the reference replica from rank %s", rank,
+                )
+                dist.broadcast_object_list([local_cpu_state], src=self.global_rank, group=self.group)
+            else:
+                global_rank = self.get_global_rank(self.group, rank)
+                # Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
+                dist.broadcast_object_list([0], src=global_rank, group=self.group)
+
+    def _collect_sharded_states(self) -> List[Dict[str, Any]]:
+        """Collect all the state shards, in CPU memory."""
+        all_states = []
+        for rank in range(self.world_size):
+            if rank == self.rank:
+                logging.debug("Saving self state")
+                all_states.append(
+                    recursive_copy_to_device(self.local_state_dict(), non_blocking=True, device=torch.device("cpu"))
+                )
+
+                # Sync with other replicas
+                dist.broadcast_object_list([0], src=self.global_rank, group=self.group)
+            else:
+                # Fetch the optim state from the other replicas
+                global_rank = self.get_global_rank(self.group, rank)
+                replica_state = [0]
+                dist.broadcast_object_list(replica_state, src=global_rank, group=self.group)
+
+                all_states.append(
+                    recursive_copy_to_device(replica_state[0], non_blocking=True, device=torch.device("cpu"))
+                )
+
+                logging.debug("State from rank %s received", rank)
+
+        return all_states
     def state_dict(self) -> Dict[str, Any]:
         """Return the last known global optimizer state, which consist of a list of the shards.

@@ -466,53 +515,6 @@ class OSS(Optimizer):
             elif k in global_group.keys():
                 local_group[k] = global_group[k]
-    def _collect_sharded_states(self) -> List[Dict[str, Any]]:
-        """Collect all the state shards, in CPU memory."""
-        empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
-        all_states: List[Dict[str, Any]] = []
-
-        for rank in range(self.world_size):
-            if rank == self.rank:
-                logging.debug("Saving self state")
-                all_states.append(
-                    recursive_copy_to_device(self.local_state_dict(), non_blocking=True, device=torch.device("cpu"))
-                )
-
-                # Sync with other replicas
-                broadcast_object(empty_buffer, src_rank=self.global_rank, group=self.group, dist_device=self._device)
-            else:
-                # Fetch the optim state from the other replicas
-                global_rank = self.get_global_rank(self.group, rank)
-                replica_state = broadcast_object(
-                    empty_buffer, src_rank=global_rank, group=self.group, dist_device=self._device
-                )
-
-                all_states.append(
-                    recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
-                )
-
-                logging.debug("State from rank %s received", rank)
-
-        return all_states
-
-    def _broadcast_state_dict(self) -> None:
-        """Broadcast this rank's state shard, discard others"""
-        empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
-
-        for rank in range(self.world_size):
-            if rank == self.rank:
-                # Send the state to the reference replica
-                logging.debug(
-                    "Sending the sharded optimizer state to the reference replica from rank %s", rank,
-                )
-                broadcast_object(
-                    self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device
-                )
-            else:
-                global_rank = self.get_global_rank(self.group, rank)
-                # Discard this tensor/rank, broadcast necessary for syncing
-                broadcast_object(empty_buffer, src_rank=global_rank, group=self.group, dist_device=self._device)
     def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""
......
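Both new methods above walk the ranks in lockstep: every rank issues a `broadcast_object_list` for every rank in the group, keeping the payload only when it needs it. This keeps the replicas in sync and works around the fact that NCCL has no object gather, as the code comment notes. A standalone sketch of the pattern (hypothetical `gather_shards` helper, default process group assumed):

```python
import torch.distributed as dist

def gather_shards(local_shard, world_size):
    """Round-robin object exchange: each rank takes a turn as broadcast source."""
    rank = dist.get_rank()
    shards = []
    for src in range(world_size):
        # Every rank issues the same collective; only `src` supplies the payload,
        # the other ranks pass a placeholder that is overwritten in place.
        buffer = [local_shard if src == rank else None]
        dist.broadcast_object_list(buffer, src=src)
        shards.append(buffer[0])
    return shards
```

A rank that only needs to acknowledge the traffic, as `_broadcast_state_dict` does for the non-reference replicas, can simply drop `buffer[0]` instead of collecting it.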
@@ -3,12 +3,10 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

-import io
 from typing import Any, Callable, Dict, List, Optional

 import torch
 from torch._six import container_abcs
-import torch.distributed as dist


 class Workhandle:

@@ -128,31 +126,3 @@ def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.devic
         return device_val

     return value
-def broadcast_object(
-    obj: Any, src_rank: int, group: object = dist.group.WORLD, dist_device: torch.device = torch.device("cpu")
-) -> Any:
-    """
-    Either broadcast from master to the fleet (default),
-    or use the src setting as the original rank.
-    """
-
-    if dist.get_rank() == src_rank:
-        # Emit data
-        buffer = io.BytesIO()
-        torch.save(obj, buffer)
-        data = bytearray(buffer.getbuffer())
-        length_tensor = torch.LongTensor([len(data)]).to(dist_device)
-        data_send_tensor = torch.ByteTensor(data).to(dist_device)
-        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
-        dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
-    else:
-        # Fetch from the source
-        length_tensor = torch.LongTensor([0]).to(dist_device)
-        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
-        data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=dist_device)
-        dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
-        buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
-        obj = torch.load(buffer, map_location=dist_device)
-
-    return obj
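The deleted helper boils down to a two-phase broadcast: send a length tensor first so receivers can size their buffer, then send the serialized payload. A minimal standalone sketch of that idiom (hypothetical `broadcast_bytes` name, default process group assumed), for comparison with the built-in collective that replaces it:

```python
import torch
import torch.distributed as dist

def broadcast_bytes(payload: bytes, src: int, device=torch.device("cpu")) -> bytes:
    """Two-phase broadcast of a variable-length byte string."""
    if dist.get_rank() == src:
        data = torch.tensor(list(payload), dtype=torch.uint8, device=device)
        length = torch.tensor([data.numel()], dtype=torch.long, device=device)
        dist.broadcast(length, src=src)  # receivers learn the size first
        dist.broadcast(data, src=src)    # then receive the payload itself
        return payload
    length = torch.zeros(1, dtype=torch.long, device=device)
    dist.broadcast(length, src=src)
    data = torch.empty(int(length.item()), dtype=torch.uint8, device=device)
    dist.broadcast(data, src=src)
    return data.cpu().numpy().tobytes()
```

`broadcast_object_list` uses essentially the same length-then-payload scheme internally (with pickle in place of `torch.save`), which is why the helper can be removed without changing the communication pattern.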
@@ -32,6 +32,7 @@ def get_backend(group: Optional[Any] = None) -> Any: ...
 def broadcast(tensor: Tensor, src: Any, group: Any, async_op: Any = False): ...
 def gather(tensor: Tensor, gather_list: Optional[List[Tensor]], dst: Any, group:Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
 def reduce(tensor: Tensor, dst: Any, op: Optional[Any]=ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
+def broadcast_object_list(object_list: List[Any], src: int, group:Optional[ProcessGroup] = None): ...
 def is_initialized() -> bool: ...
......
@@ -392,12 +392,11 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
     else:
         optimizer_state_dict = {}
-    optimizer_state_dict = optim.utils.broadcast_object(
-        optimizer_state_dict, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device
-    )
+    optim_state = [optimizer_state_dict]
+    dist.broadcast_object_list(optim_state, src=reference_rank, group=dist.group.WORLD)

     # Load the optimizer state dict
-    optimizer.load_state_dict(optimizer_state_dict)
+    optimizer.load_state_dict(optim_state[0])
     dist.destroy_process_group()
......