Commit 8e363567 authored by Benjamin Lefaudeux, committed by Mandeep Singh Baines

[feat] Implement OSS save and load of the sharded state from a single replica (#16)

parent bfba68d8
@@ -4,11 +4,15 @@
# LICENSE file in the root directory of this source tree.
import copy
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type

import torch
import torch.distributed as dist
from torch.optim import SGD, Optimizer

from .utils import broadcast_object, recursive_copy_to_device

if TYPE_CHECKING:
    from torch.optim.optimizer import _params_t
else:
@@ -17,7 +21,7 @@ else:
class OSS(Optimizer):
    """Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
    optimizer and shards its state as described by ZeRO_.

    ::

        opt = OSS(params, optim=torch.optim.Adam, lr=0.01)
@@ -54,6 +58,12 @@ class OSS(Optimizer):
        param_groups = self.partition_parameters()
        self.optim = optim(param_groups[self.rank], **defaults)

        # Optional consolidated optimizer state
        self._all_states: List[Dict[str, Any]] = []

        # Current device is set by the parameters allocated to this rank
        self._device = self.partition_parameters()[self.rank][0]["params"][0].device
    def partition_parameters(self) -> List[List[dict]]:
        """Partitions parameters across distributed ranks.
@@ -73,10 +83,10 @@ class OSS(Optimizer):
                param_lists[rank].append(param)
                sizes[rank] += param.numel()
            for rank, params in enumerate(param_lists):
                if len(params) > 0:
                    param_group_rank = copy.copy(param_group)
                    param_group_rank["params"] = params
                    param_groups[rank].append(param_group_rank)
        return param_groups
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
@@ -87,13 +97,50 @@ class OSS(Optimizer):
                    dist.broadcast(param, rank, group=self.group)
        return loss
    def local_state_dict(self) -> dict:
        """ Gets this rank's state_dict. """
        return self.optim.state_dict()

    def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
        """ Update the consolidated state_dict list, one per rank.

        This needs to be called on all replicas """

        if self.rank == recipient_rank:
            # Pull the sharded state from all the other replicas
            # Store all the states in order, rank by rank
            logging.debug("Pulling the sharded SGD state from all replicas")
            self._all_states = self._collect_sharded_states()
        else:
            # Acknowledge broadcasts, and send this rank's shard when needed
            self._broadcast_state_dict()
    def state_dict(self) -> Dict[str, Any]:
        """
        Return the last known global optimizer state, which consists of a list of the shards.

        NOTE: This is limited to the replica which was responsible for the consolidation.
        The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
        """

        assert (
            len(self._all_states) > 0
        ), "The optimizer state is not materialized, please call consolidate_state_dict on every replica beforehand"

        return {"states": self._all_states}
    def load_local_state_dict(self, state_dict: dict) -> None:
        """ Loads this rank's state_dict. """

        # Make sure that the state is on the appropriate device
        state_dict_ondevice = recursive_copy_to_device(state_dict, non_blocking=False, device=self._device)

        self.optim.load_state_dict(state_dict_ondevice)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """ Loads this rank's optimizer state_dict, given the global optimizer state. """

        # Dispatch this rank's state dictionary to the local load
        self.load_local_state_dict(state_dict["states"][self.rank])
    def add_param_group(self, param_group: dict) -> None:
        super().add_param_group(param_group)
@@ -101,3 +148,52 @@ class OSS(Optimizer):
            param_groups = self.partition_parameters()[self.rank]
            if len(param_groups) == len(self.optim.param_groups) + 1:
                self.optim.add_param_group(param_groups[-1])
    def _collect_sharded_states(self) -> List[Dict[str, Any]]:
        """
        Collect all the state shards, in CPU memory.
        """
        empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
        all_states: List[Dict[str, Any]] = []

        for rank in range(dist.get_world_size(group=self.group)):
            if rank == self.rank:
                logging.debug("Saving self state")
                all_states.append(
                    recursive_copy_to_device(self.local_state_dict(), non_blocking=True, device=torch.device("cpu"))
                )

                # Sync with other replicas
                broadcast_object(empty_buffer, src_rank=rank, group=self.group, dist_device=self._device)
            else:
                # Fetch the optim state from the other replicas
                logging.debug("Receiving state from rank %s ", rank)
                replica_state = broadcast_object(
                    empty_buffer, src_rank=rank, group=self.group, dist_device=self._device
                )

                all_states.append(
                    recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
                )

                logging.debug("State from rank %s received", rank)

        return all_states
    def _broadcast_state_dict(self) -> None:
        """
        Broadcast this rank's state shard, discard others
        """
        empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)

        for rank in range(dist.get_world_size(group=self.group)):
            if rank == self.rank:
                # Send the state to the reference replica
                logging.debug(
                    "Sending the sharded SGD state to the reference replica from rank %s", rank,
                )
                broadcast_object(self.local_state_dict(), src_rank=rank, group=self.group, dist_device=self._device)
            else:
                # Discard this tensor/rank, broadcast necessary for syncing
                logging.debug("Discarding broadcast from rank %s", rank)
                broadcast_object(empty_buffer, src_rank=rank, group=self.group, dist_device=self._device)
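Taken together, the new methods give a save/restore path for the sharded optimizer state that is driven from a single replica: every rank calls consolidate_state_dict(), the recipient rank materializes state_dict() and persists it, and on restore every rank calls load_state_dict(), which picks out its own shard. A minimal sketch of that flow follows; the checkpoint path, the use of torch.save/torch.load and the helper names are my own illustration, not part of this diff, and it assumes the process group and the OSS optimizer already exist on every rank.

import torch
import torch.distributed as dist
from fairscale.optim import OSS

RECIPIENT_RANK = 0  # replica that owns the consolidated state (illustrative choice)

def save_sharded_optimizer(optimizer: OSS, path: str) -> None:
    # Collective call: the recipient pulls every shard, the other ranks
    # answer the broadcasts with their own local shard.
    optimizer.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)
    if dist.get_rank() == RECIPIENT_RANK:
        # state_dict() is only materialized on the recipient rank
        torch.save(optimizer.state_dict(), path)

def restore_sharded_optimizer(optimizer: OSS, path: str) -> None:
    # Every rank loads the consolidated state; load_state_dict() dispatches
    # the shard matching this rank to the wrapped optimizer.
    state = torch.load(path, map_location="cpu")
    optimizer.load_state_dict(state)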
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import io
from typing import Any, Dict

import torch
from torch._six import container_abcs
import torch.distributed as dist


# Credits: classy_vision/generic/distributed_util.py
def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
    """
    Recursively searches lists, tuples, dicts and copies tensors to device if
    possible. Non-tensor values are passed as-is in the result.

    NOTE: These are all copies, so if there are two objects that reference
    the same object, then after this call, there will be two different objects
    referenced on the device.
    """
    if isinstance(value, torch.Tensor):
        return value.to(device, non_blocking=non_blocking)

    if isinstance(value, (list, tuple)):
        values = []
        for val in value:
            values.append(recursive_copy_to_device(val, non_blocking=non_blocking, device=device))

        return values if isinstance(value, list) else tuple(values)

    if isinstance(value, container_abcs.Mapping):
        device_val: Dict[str, Any] = {}
        for key, val in value.items():
            device_val[key] = recursive_copy_to_device(val, non_blocking=non_blocking, device=device)

        return device_val

    return value
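As a quick illustration of recursive_copy_to_device, the sketch below (my own example, assuming the helper is importable as fairscale.optim.utils.recursive_copy_to_device, as the tests further down suggest) moves a nested, optimizer-like container to CPU; tensors are copied to the target device while plain Python values pass through unchanged.

import torch
from fairscale.optim.utils import recursive_copy_to_device

# Hypothetical nested state: tensors mixed with plain Python values
state = {
    "step": 10,
    "exp_avg": [torch.ones(2)],
    "hyperparams": ("lr", 0.1),
}

cpu_state = recursive_copy_to_device(state, non_blocking=False, device=torch.device("cpu"))
assert cpu_state["exp_avg"][0].device == torch.device("cpu")  # tensor copied to the target device
assert cpu_state["step"] == 10 and cpu_state["hyperparams"] == ("lr", 0.1)  # non-tensors pass through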
def broadcast_object(
    obj: Any, src_rank: int, group: object = dist.group.WORLD, dist_device: torch.device = torch.device("cpu")
) -> Any:
    """
    Either broadcast from master to the fleet (default),
    or use the src setting as the original rank.
    """

    if dist.get_rank() == src_rank:
        # Emit data
        buffer = io.BytesIO()
        torch.save(obj, buffer)  # type: ignore
        data = bytearray(buffer.getbuffer())
        length_tensor = torch.LongTensor([len(data)]).to(dist_device)
        data_send_tensor = torch.ByteTensor(data).to(dist_device)
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
    else:
        # Fetch from the source
        length_tensor = torch.LongTensor([0]).to(dist_device)
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=dist_device)
        dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
        buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
        obj = torch.load(buffer, map_location=dist_device)  # type: ignore

    return obj
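To make the protocol above concrete (a length broadcast followed by the serialized payload, so every rank in the group must call the function), here is a minimal usage sketch; the payload and the choice of rank 0 as the source are illustrative, and dist_device should be a device the backend can actually communicate from (for example a CUDA device under NCCL).

import torch
import torch.distributed as dist
from fairscale.optim.utils import broadcast_object

def share_metadata(rank: int) -> dict:
    # The source rank sends an arbitrary picklable object; the other ranks
    # pass a placeholder and get the deserialized copy as the return value.
    payload = {"epoch": 3, "best_loss": 0.123} if rank == 0 else {}
    return broadcast_object(
        payload, src_rank=0, group=dist.group.WORLD, dist_device=torch.device("cpu")
    )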
@@ -3,6 +3,10 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

import os
import pytest

@@ -14,17 +18,20 @@ import fairscale.optim as optim
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")

BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")
def setup_module(module):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend=BACKEND, rank=0, world_size=1)


def dist_init(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group(backend=BACKEND, rank=rank, world_size=world_size)
def test_create():
@@ -32,17 +39,29 @@ def test_create():
    o = optim.OSS(params, lr=0.01)


def test_state_dict():
    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
    o = optim.OSS([x], lr=0.1)
    o.consolidate_state_dict()  # Sync state dict in between replicas - even if there are none
    state_dict = o.state_dict()

    o = optim.OSS([x], lr=0.01)
    o.load_state_dict(state_dict)

    # We should now be using a lr of 0.1.
    x.backward()
    o.step()
    assert x == torch.tensor([0.9], device=DEVICE)
def test_local_state_dict():
    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
    o = optim.OSS([x], lr=0.1)
    local_state_dict = o.local_state_dict()

    o = optim.OSS([x], lr=0.01)
    o.load_local_state_dict(local_state_dict)

    # We should now be using a lr of 0.1.
    x.backward()
    o.step()
    assert x == torch.tensor([0.9], device=DEVICE)
def run_test_add_param_group(rank, world_size):
@@ -57,9 +76,9 @@
    # Verify that added group is added to the correct partition making all have 8 elements.
    assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == 8
    if rank == 1:
        assert len(o.optim.param_groups) == 2
    else:
        assert len(o.optim.param_groups) == 1
def test_add_param_group():
@@ -81,7 +100,6 @@ def run_test_zero_grad(rank, world_size):
    assert not m.bias.grad


def test_zero_grad():
    world_size = 2
    mp.spawn(run_test_zero_grad, args=(world_size,), nprocs=world_size, join=True)
@@ -111,8 +129,9 @@ def test_step():
    mp.spawn(run_test_step, args=(world_size,), nprocs=world_size, join=True)


def run_test_step_with_closure(rank, world_size, optimizer=None):
    dist_init(rank, world_size)

    x_val = rank + 1
    weight = 1.0
    bias = 2.0
@@ -125,7 +144,9 @@
    m.weight.data = torch.tensor([[weight]])
    m.bias.data = torch.tensor([bias])
    m.to(rank)

    o = optim.OSS(m.parameters(), lr=0.1)

    y = m(x)
    y.backward(x)
    for p in m.parameters():
@@ -164,3 +185,59 @@ def run_test_sharding(rank, world_size):
def test_sharding():
    world_size = 3
    mp.spawn(run_test_sharding, args=(world_size,), nprocs=world_size, join=True)
def run_test_collect_shards(rank, world_size, reference_rank):
    dist_init(rank, world_size)
    device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE

    # Run a dummy step so that the optimizer state dict exists
    batch, input_width, hidden, target_width = 3, 20, 10, 5
    target = torch.rand((batch, target_width), device=device)
    inputs = torch.rand((batch, input_width), device=device)

    model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width))
    model.to(device)

    loss_fn = torch.nn.L1Loss()
    loss_fn.to(device)

    # With SGD, Momentum is required to get a state to shard
    optimizer = optim.OSS(model.parameters(), lr=0.1, momentum=0.99)

    def closure():
        optimizer.zero_grad()
        output = model(inputs)
        loss = loss_fn(output, target)
        loss.backward()
        return loss

    _ = optimizer.step(closure=closure)

    # Update the optimizer state on the reference rank
    optimizer.consolidate_state_dict(recipient_rank=reference_rank)

    # Fetch the state on the reference rank
    # - check that it has the correct size
    # - load it again
    if rank == reference_rank:
        optimizer_state_dict = optimizer.state_dict()
        assert len(optimizer_state_dict["states"]) == world_size
    else:
        optimizer_state_dict = {}

    optimizer_state_dict = optim.utils.broadcast_object(
        optimizer_state_dict, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device
    )

    # Load the optimizer state dict
    optimizer.load_state_dict(optimizer_state_dict)


def test_collect_shards():
    world_size = 3
    reference_rank = 0
    mp.spawn(
        run_test_collect_shards, args=(world_size, reference_rank), nprocs=world_size, join=True,
    )