Unverified Commit 9e8929e6 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[feat][OSS] elastic and pytorch compatible checkpoints (#310)

* adding a test to prove the inter operability with upstream pytorch
* updating the changelog
* eager state pruning
* pytorch 1.5 compat
parent c2dd6c34
...@@ -25,3 +25,4 @@ venv/ ...@@ -25,3 +25,4 @@ venv/
ENV/ ENV/
env.bak/ env.bak/
venv.bak/ venv.bak/
.vscode/*
...@@ -6,8 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ...@@ -6,8 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [next rel] - TBD ## [next rel] - TBD
### Added ### Added
- Pytorch compatibility for OSS checkpoints (#310)
- Elastic checkpoints for OSS, world size can vary in between save and loads (#310)
- Tensor views for OSS bucketing, reduced CPU use (#300)
- Bucket calls in ShardedDDP, for faster inter node communications (#327) - Bucket calls in ShardedDDP, for faster inter node communications (#327)
- Tensor views for OSS bucketing, reduced CPU use
## [0.1.4] - 2021-01-07 ## [0.1.4] - 2021-01-07
### Fixed ### Fixed
......
...@@ -5,11 +5,10 @@ ...@@ -5,11 +5,10 @@
from collections import OrderedDict, deque from collections import OrderedDict, deque
import copy import copy
import itertools
from itertools import chain from itertools import chain
import logging import logging
from math import inf from math import inf
from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tuple, Type, Union from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Type, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -80,6 +79,8 @@ class OSS(Optimizer): ...@@ -80,6 +79,8 @@ class OSS(Optimizer):
self._per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict() # device, rank, params self._per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict() # device, rank, params
self._param_rank: Dict[torch.Tensor, int] = {} self._param_rank: Dict[torch.Tensor, int] = {}
self._partition_parameters: List[List[dict]] = [] self._partition_parameters: List[List[dict]] = []
self._index_to_param: Dict[int, torch.Tensor] = {}
self._param_to_index: Dict[int, int] = {}
# Build the wrapped optimizer, responsible for a shard of the params # Build the wrapped optimizer, responsible for a shard of the params
self.group = group if group is not None else dist.group.WORLD self.group = group if group is not None else dist.group.WORLD
...@@ -142,6 +143,24 @@ class OSS(Optimizer): ...@@ -142,6 +143,24 @@ class OSS(Optimizer):
return self._partition_parameters return self._partition_parameters
@property
def index_to_param(self) -> Dict[int, torch.Tensor]:
""" Hash table in between parameter indices in the global optimizer scheme, and the actual params
"""
if len(self._index_to_param) == 0:
self._index_to_param = {i: p for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))}
return self._index_to_param
@property
def param_to_index(self) -> Dict[int, int]:
""" Hash table in between parameter indices in the global optimizer scheme, and the actual params
"""
if len(self._param_to_index) == 0:
self._param_to_index = {id(p): i for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))}
return self._param_to_index
@property @property
def per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]: def per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]:
"""Sorted list of all the params, first per device then per rank. """Sorted list of all the params, first per device then per rank.
...@@ -191,7 +210,7 @@ class OSS(Optimizer): ...@@ -191,7 +210,7 @@ class OSS(Optimizer):
.. note: Any extra parameter is passed to the base optimizer as-is""" .. note: Any extra parameter is passed to the base optimizer as-is"""
# Sync oss param_groups attributes in case they've been updated by a scheduler. # Sync oss param_groups attributes in case they've been updated by a scheduler.
self._sync_param_groups() OSS._sync_param_groups(self.param_groups, self.optim.param_groups)
# Run the optimizer step on this shard only: # Run the optimizer step on this shard only:
if closure is not None: if closure is not None:
...@@ -203,7 +222,7 @@ class OSS(Optimizer): ...@@ -203,7 +222,7 @@ class OSS(Optimizer):
self._broadcast_params() self._broadcast_params()
# Sync hypothethical new results from the wrapped optimizer to the exposed param_groups # Sync hypothethical new results from the wrapped optimizer to the exposed param_groups
self._sync_param_groups(local_to_global=True) OSS._sync_param_groups(self.optim.param_groups, self.param_groups)
return loss return loss
...@@ -237,7 +256,7 @@ class OSS(Optimizer): ...@@ -237,7 +256,7 @@ class OSS(Optimizer):
norm_type = float(norm_type) norm_type = float(norm_type)
# Filter out the grad-less params, concatenate params from all devices # Filter out the grad-less params, concatenate params from all devices
local_params = itertools.chain( local_params = chain(
*[ *[
list(filter(lambda x: x.grad is not None, device_params[self.rank])) list(filter(lambda x: x.grad is not None, device_params[self.rank]))
for device_params in self.per_device_params.values() for device_params in self.per_device_params.values()
...@@ -280,26 +299,13 @@ class OSS(Optimizer): ...@@ -280,26 +299,13 @@ class OSS(Optimizer):
return total_norm return total_norm
# State dict interfaces # State dict interfaces
def local_state_dict(self) -> dict:
"""Gets this rank's state_dict.
Returns:
The state of the optimizer as a :class:`dict`.
It contains two entries:
* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a dict containing all parameter groups
"""
return self.optim.state_dict()
def consolidate_state_dict(self, recipient_rank: int = 0) -> None: def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
"""Update the consolidated state_dict list, one per rank. """Update the consolidated state_dict list, one per rank.
.. warning: This needs to be called on all replicas""" .. warning: This needs to be called on all replicas"""
# Sync lr and other attributes in case its been updated # Sync lr and other attributes in case its been updated
self._sync_param_groups() OSS._sync_param_groups(self.param_groups, self.optim.param_groups)
if self.rank == recipient_rank: if self.rank == recipient_rank:
# Pull the sharded state from all the other replicas # Pull the sharded state from all the other replicas
...@@ -310,12 +316,104 @@ class OSS(Optimizer): ...@@ -310,12 +316,104 @@ class OSS(Optimizer):
# Acknowledge broadcasts, and send this rank's shard when needed # Acknowledge broadcasts, and send this rank's shard when needed
self._broadcast_state_dict() self._broadcast_state_dict()
def local_state_dict(self) -> dict:
""" .. deprecated:: 0.1.5
Returns this rank's state_dict as a :class:`dict` which contains two entries:
* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a dict containing all parameter groups
.. warning: This does not represent the optimizer state dict, only a shard.
"""
return self.optim.state_dict()
def state_dict(self) -> Dict[str, Any]:
"""Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the
sharded properties are not exposed. It contains two entries:
* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a dict containing all parameter groups
.. warning:
If the state has not been consolidated, this returns a shard's worth, not the global state.
.. warning:
Returning the global state is limited to the replica which was responsible for the consolidation.
The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
"""
if len(self._all_states) == 0:
raise RuntimeError(
"Optimizer state has not been consolidated on this rank. \
Please call `consolidate_state_dict()` on all ranks beforehand if you meant to save the global state"
)
# Unify the shard states and the state that pytorch would expect, given the model.
# Indexation needs several redirections, since each shard only knows a limited scope of the model
# - get the pytorch compliant parameter indexing
state_dict = super().state_dict()
# - go through the per-shard states, which are all indexed locally
for rank, s in enumerate(self._all_states):
# -- match the local indexing and the global partition, update the corresponding saved state globally
for local_pg, global_pg in zip(s["param_groups"], self.partition_parameters()[rank]):
local_index_to_param_id = {
i_param: id(global_pg["params"][i]) for i, i_param in enumerate(local_pg["params"])
}
for local_param_index in local_pg["params"]:
# Update the state, if any
if local_param_index in s["state"].keys():
global_id = self.param_to_index[local_index_to_param_id[local_param_index]]
state_dict["state"][global_id] = s["state"][local_param_index]
return state_dict
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""Restore the global parameter groups as well as the shard.
Arguments:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`
"""
# NOTE: PyTorch 1.5 does not index linearly but with the id(params) at saving time
# we work around that here by using the fact that the params are ordered as in the param_groups
for i_param, (key, value) in enumerate(state_dict["state"].items()):
param = self.index_to_param[i_param]
# Populate the sharded optimizer state on the fly
if self.param_to_rank[param] != self.rank:
state_dict["state"][key] = None
if key in self.index_to_param:
param = self.index_to_param[i_param]
# Only add this state to the sharded optimizer if it owns this param
for pg in self.optim.param_groups:
if id(param) in [id(p) for p in pg["params"]]:
self.optim.state[param] = recursive_copy_to_device(
value, non_blocking=True, device=param.device
)
super().load_state_dict(state_dict)
# Sync with the optimizer param groups
OSS._sync_param_groups(state_dict["param_groups"], self.param_groups)
OSS._sync_param_groups(self.param_groups, self.optim.param_groups)
def _broadcast_state_dict(self) -> None: def _broadcast_state_dict(self) -> None:
"""Broadcast this rank's state shard, discard others""" """Broadcast this rank's state shard, discard others"""
# Default to CPU space to gain some memory headroom # Default to CPU space to gain some memory headroom
local_cpu_state = recursive_copy_to_device( local_cpu_state = recursive_copy_to_device(
self.local_state_dict(), non_blocking=True, device=torch.device("cpu") self.optim.state_dict(), non_blocking=True, device=torch.device("cpu")
) )
# Tensor cannot be really empty, even if its size is meaningless # Tensor cannot be really empty, even if its size is meaningless
...@@ -350,7 +448,7 @@ class OSS(Optimizer): ...@@ -350,7 +448,7 @@ class OSS(Optimizer):
if rank == self.rank: if rank == self.rank:
logging.debug("Saving self state") logging.debug("Saving self state")
all_states.append( all_states.append(
recursive_copy_to_device(self.local_state_dict(), non_blocking=True, device=torch.device("cpu")) recursive_copy_to_device(self.optim.state_dict(), non_blocking=True, device=torch.device("cpu"))
) )
# Sync with other replicas # Sync with other replicas
...@@ -378,103 +476,6 @@ class OSS(Optimizer): ...@@ -378,103 +476,6 @@ class OSS(Optimizer):
return all_states return all_states
def state_dict(self) -> Dict[str, Any]:
"""Return the last known global optimizer state, which consist of a list of the shards.
.. warning:
If the state has not been consolidated, this returns a shard's worth, not the global state.
.. warning:
Returning the global state is limited to the replica which was responsible for the consolidation.
The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
"""
if len(self._all_states) == 0:
logging.warning("Optimizer state has not been consolidated. Returning the local state")
logging.warning("Please call `consolidate_state_dict()` beforehand if you meant to save the global state")
state_dict = self.local_state_dict()
state_dict["local_state_dict"] = True
return state_dict
# Flatten the param_groups, save the partition which logs the rank <> shard correspondence
partition: List[Tuple[int, int]] = []
param_groups: List[Dict[Any, Any]] = []
start = 0
for i, s in enumerate(self._all_states):
param_groups.extend(s["param_groups"])
end = start + len(s["param_groups"])
partition.append((start, end))
start = end
return {
"state": [s["state"] for s in self._all_states],
"param_groups": param_groups,
"partition": partition,
"local_state_dict": False,
}
@staticmethod
def rank_local_state_dict(rank: int, state_dict: dict) -> dict:
"""Returns the local_state_dict for a given rank.
Arguments:
rank (int): rank to get local_state_dict for
state_dict (dict): global state_dict
"""
param_groups = state_dict["param_groups"][state_dict["partition"][rank][0] : state_dict["partition"][rank][1]]
return {"state": state_dict["state"][rank], "param_groups": param_groups}
def load_local_state_dict(self, state_dict: dict) -> None:
"""Loads this rank's state_dict.
.. warning: This is not meant to load the global state dict.
"""
self.optim.load_state_dict(state_dict)
# Workaround PyTorch bug that casts state (https://github.com/pytorch/pytorch/issues/43706)
# Copied from https://github.com/pytorch/fairseq/blob/v0.9.0/fairseq/optim/fp16_optimizer.py#L251-L268
groups = self.optim.param_groups
saved_groups = state_dict["param_groups"]
id_map = {
old_id: p
for old_id, p in zip(chain(*(g["params"] for g in saved_groups)), chain(*(g["params"] for g in groups)))
}
for k, v in state_dict["state"].items():
if k in id_map:
param = id_map[k]
self.optim.state[param] = recursive_copy_to_device(v, non_blocking=True, device=param.device)
# Restore the global param_groups (the params themselves are already correct)
for global_group, local_group in zip(self.param_groups, groups):
for k, v in local_group.items():
if k != "params":
global_group[k] = v
# Force a re-partitioning, in case the model changed with the new state
self._partition_parameters.clear()
self._per_device_params.clear()
self._param_rank.clear()
# Update the bucketing strategy accordingly
self._setup_bucket_strategy()
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""Restore the global parameter groups as well as the shard.
Arguments:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`
"""
# Check whether we got a local or global dict
if "local_state_dict" in state_dict and state_dict["local_state_dict"]:
self.load_local_state_dict(state_dict)
else:
# Dispatch this rank's state dictionary to the wrapped shard optimizer
self.load_local_state_dict(OSS.rank_local_state_dict(self.rank, state_dict))
def add_param_group(self, param_group: dict) -> None: def add_param_group(self, param_group: dict) -> None:
"""Add a param group to the :class:`Optimizer` s `param_groups`. """Add a param group to the :class:`Optimizer` s `param_groups`.
...@@ -491,10 +492,9 @@ class OSS(Optimizer): ...@@ -491,10 +492,9 @@ class OSS(Optimizer):
super().add_param_group(param_group) super().add_param_group(param_group)
if not self.in_super_constructor: if not self.in_super_constructor:
# Force a re-partitioning # Force a re-partitioning
self._partition_parameters.clear() self._clear_cache()
self._per_device_params.clear()
self._param_rank.clear()
# Update the partition
param_groups = self.partition_parameters()[self.rank] param_groups = self.partition_parameters()[self.rank]
if len(param_groups) == len(self.optim.param_groups) + 1: if len(param_groups) == len(self.optim.param_groups) + 1:
self.optim.add_param_group(param_groups[-1]) self.optim.add_param_group(param_groups[-1])
...@@ -502,6 +502,13 @@ class OSS(Optimizer): ...@@ -502,6 +502,13 @@ class OSS(Optimizer):
# Update the bucketing strategy accordingly # Update the bucketing strategy accordingly
self._setup_bucket_strategy() self._setup_bucket_strategy()
def _clear_cache(self) -> None:
self._partition_parameters.clear()
self._per_device_params.clear()
self._param_rank.clear()
self._index_to_param.clear()
self._param_to_index.clear()
@staticmethod @staticmethod
def get_global_rank(group: Any, rank: int) -> int: def get_global_rank(group: Any, rank: int) -> int:
if group is dist.group.WORLD: if group is dist.group.WORLD:
...@@ -510,20 +517,14 @@ class OSS(Optimizer): ...@@ -510,20 +517,14 @@ class OSS(Optimizer):
global_rank = dist.distributed_c10d._get_global_rank(group, rank) global_rank = dist.distributed_c10d._get_global_rank(group, rank)
return global_rank return global_rank
@torch.no_grad() @staticmethod
def _sync_param_groups(self, local_to_global: bool = False) -> None: def _sync_param_groups(source: List[Dict[Any, Any]], destination: List[Dict[Any, Any]]) -> None:
"""Sync learning rate and other optimizer attributes (needed to support schedulers). """Sync learning rate and other optimizer attributes (needed to support schedulers)."""
If the global param groups have been altered, and we want to make sure that the
wrapped optimizer uses the up to date version.
Conversely if the wrapped optimizer has new keys, we expose them through the global param groups"""
for global_group, local_group in zip(self.param_groups, self.optim.param_groups): for source_group, destination_group in zip(source, destination):
# Sync everything but the parameters # Sync everything but the parameters
for k in filter(lambda x: x != "params", local_group.keys()): for k in filter(lambda x: x != "params", source_group.keys()):
if local_to_global: destination_group[k] = source_group[k]
global_group[k] = local_group[k]
elif k in global_group.keys():
local_group[k] = global_group[k]
@torch.no_grad() @torch.no_grad()
def _broadcast_params(self) -> None: def _broadcast_params(self) -> None:
...@@ -614,7 +615,6 @@ class OSS(Optimizer): ...@@ -614,7 +615,6 @@ class OSS(Optimizer):
self.buckets[device][dst_rank][offset:offset_next].copy_(param.data.flatten()) self.buckets[device][dst_rank][offset:offset_next].copy_(param.data.flatten())
param.data = self.buckets[device][dst_rank][offset:offset_next].view_as(param.data) param.data = self.buckets[device][dst_rank][offset:offset_next].view_as(param.data)
offset = offset_next offset = offset_next
else: else:
self.should_bucket_param.append(False) self.should_bucket_param.append(False)
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
import copy import copy
from math import inf from math import inf
import tempfile import tempfile
from typing import Type, cast from typing import Any, Type, cast
import unittest import unittest
import numpy as np import numpy as np
...@@ -26,6 +26,7 @@ from fairscale.utils.testing import skip_if_no_cuda, skip_if_single_gpu ...@@ -26,6 +26,7 @@ from fairscale.utils.testing import skip_if_no_cuda, skip_if_single_gpu
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore
DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu") DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")
RECIPIENT_RANK = 1
try: try:
from torch.distributed import broadcast_object_list # noqa from torch.distributed import broadcast_object_list # noqa
...@@ -42,6 +43,19 @@ def dist_init(rank, world_size, tempfile_name, backend=BACKEND): ...@@ -42,6 +43,19 @@ def dist_init(rank, world_size, tempfile_name, backend=BACKEND):
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size) dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
def sync_object_ranks(something_to_sync: Any, reference_rank: int, device: torch.device) -> Any:
if _torch_broadcast_object:
package = [something_to_sync]
dist.broadcast_object_list(package, src=reference_rank, group=dist.group.WORLD)
package_sync = package[0]
else:
package_sync = optim.utils.broadcast_object(
something_to_sync, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device
)
return package_sync
class TestSingleRank(unittest.TestCase): class TestSingleRank(unittest.TestCase):
""" """
All the following tests do not check for inter-process communication All the following tests do not check for inter-process communication
...@@ -158,31 +172,11 @@ class TestSingleRank(unittest.TestCase): ...@@ -158,31 +172,11 @@ class TestSingleRank(unittest.TestCase):
o.step() o.step()
assert x == torch.tensor([0.9], device=DEVICE) assert x == torch.tensor([0.9], device=DEVICE)
def test_local_state_dict(self):
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.1)
local_state_dict = o.local_state_dict()
o = optim.OSS([x], lr=0.01)
o.load_local_state_dict(local_state_dict)
# We should now be using a lr of 0.1.
assert o.optim.param_groups[0]["lr"] == 0.1
assert o.param_groups[0]["lr"] == 0.1
x.backward()
o.step()
assert x == torch.tensor([0.9], device=DEVICE)
def test_implicit_local_state_dict(self): def test_implicit_local_state_dict(self):
x = torch.tensor([1.0], device=DEVICE, requires_grad=True) x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.1) o = optim.OSS([x], lr=0.1)
local_state_dict = o.state_dict() with pytest.raises(RuntimeError):
o = optim.OSS([x], lr=0.01) _ = o.state_dict()
o.load_state_dict(local_state_dict)
# We should now be using a lr of 0.1.
assert o.optim.param_groups[0]["lr"] == 0.1
assert o.param_groups[0]["lr"] == 0.1
x.backward()
o.step()
assert x == torch.tensor([0.9], device=DEVICE)
def run_test_add_param_group(rank, world_size, tempfile_name): def run_test_add_param_group(rank, world_size, tempfile_name):
...@@ -348,7 +342,10 @@ def test_step_with_closure(): ...@@ -348,7 +342,10 @@ def test_step_with_closure():
def run_test_sharding(rank, world_size, tempfile_name): def run_test_sharding(rank, world_size, tempfile_name):
dist_init(rank, world_size, tempfile_name) dist_init(rank, world_size, tempfile_name)
params = [] params = []
for size in [5, 4, 2, 6, 4, 3]: sizes = [9, 7, 5, 3]
sizes_world = sizes * world_size
for size in sizes_world:
params.append(torch.rand(size, 1)) params.append(torch.rand(size, 1))
# Make sure that the params are trainable, enforces size-based partitioning # Make sure that the params are trainable, enforces size-based partitioning
...@@ -356,17 +353,17 @@ def run_test_sharding(rank, world_size, tempfile_name): ...@@ -356,17 +353,17 @@ def run_test_sharding(rank, world_size, tempfile_name):
p.requires_grad = True p.requires_grad = True
o = optim.OSS(params, lr=0.1) o = optim.OSS(params, lr=0.1)
assert sum([x.numel() for x in o.optim.param_groups[0]["params"]]) == 8 assert sum([x.numel() for x in o.optim.param_groups[0]["params"]]) == sum(sizes)
dist.destroy_process_group() dist.destroy_process_group()
def test_sharding(): def test_sharding():
world_size = 3 world_size = 4
if not torch.cuda.is_available() or torch.cuda.device_count() < world_size: if torch.cuda.is_available():
pytest.skip("Not enough GPUs for NCCL-based test") world_size = min(world_size, torch.cuda.device_count())
temp_file_name = tempfile.mkstemp()[1]
_, temp_file_name = tempfile.mkstemp()
mp.spawn(run_test_sharding, args=(world_size, temp_file_name), nprocs=world_size, join=True) mp.spawn(run_test_sharding, args=(world_size, temp_file_name), nprocs=world_size, join=True)
...@@ -405,18 +402,12 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name): ...@@ -405,18 +402,12 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
# - load it again # - load it again
if rank == reference_rank: if rank == reference_rank:
optimizer_state_dict = optimizer.state_dict() optimizer_state_dict = optimizer.state_dict()
assert len(optimizer_state_dict["state"]) == world_size assert len(optimizer_state_dict["state"]) == len(list(model.parameters()))
else: else:
optimizer_state_dict = {} optimizer_state_dict = {}
optim_state = [optimizer_state_dict] # distribute to the other ranks
if _torch_broadcast_object: optimizer_state_dict = sync_object_ranks(optimizer_state_dict, reference_rank, device)
dist.broadcast_object_list(optim_state, src=reference_rank, group=dist.group.WORLD)
optimizer_state_dict = optim_state[0]
else:
optimizer_state_dict = optim.utils.broadcast_object(
optimizer_state_dict, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device
)
# Load the optimizer state dict # Load the optimizer state dict
optimizer.load_state_dict(optimizer_state_dict) optimizer.load_state_dict(optimizer_state_dict)
...@@ -436,6 +427,72 @@ def test_collect_shards(): ...@@ -436,6 +427,72 @@ def test_collect_shards():
) )
def run_test_reproducibility(rank, world_size, reference_rank, tempfile_name):
dist_init(rank, world_size, tempfile_name)
device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE
# Run a dummy step so that the optimizer state dict exists
batch, input_width, hidden, target_width = 3, 3, 3, 5
target = torch.rand((batch, target_width), device=device)
inputs = torch.rand((batch, input_width), device=device)
model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width))
model.to(device)
loss_fn = torch.nn.L1Loss()
loss_fn.to(device)
optimizer = optim.OSS(model.parameters(), optim=torch.optim.RMSprop, lr=0.1)
def closure():
optimizer.zero_grad()
output = model(inputs)
loss = loss_fn(output, target)
loss.backward()
return loss
_ = optimizer.step(closure=closure)
# Update the optimizer state on the reference rank
optimizer.consolidate_state_dict(recipient_rank=reference_rank)
# Fetch the state on the reference rank, broadcast to the other ones
if rank == reference_rank:
optimizer_state_dict = optimizer.state_dict()
else:
optimizer_state_dict = {}
# Run two steps, log the loss
_ = optimizer.step(closure=closure)
reference_loss = optimizer.step(closure=closure)
# Load the optimizer state dict, rewind the state two steps back
optimizer.load_state_dict(optimizer_state_dict)
# Run two new steps, log the loss again and check that we get the same
_ = optimizer.step(closure=closure)
test_loss = optimizer.step(closure=closure)
assert torch.allclose(reference_loss, test_loss)
dist.destroy_process_group()
def test_reproducibility():
world_size = 2
temp_file_name = tempfile.mkstemp()[1]
if torch.cuda.is_available() and torch.cuda.device_count() < world_size:
# Bail out if not enough devices
return
reference_rank = 0
mp.spawn(
run_test_collect_shards, args=(world_size, reference_rank, temp_file_name), nprocs=world_size, join=True,
)
def run_test_multiple_groups(rank, world_size, tempfile_name): def run_test_multiple_groups(rank, world_size, tempfile_name):
# Only work with the even ranks, to check that the global_rank indexing is properly used # Only work with the even ranks, to check that the global_rank indexing is properly used
dist_init(rank=rank, world_size=world_size, tempfile_name=tempfile_name, backend="gloo") dist_init(rank=rank, world_size=world_size, tempfile_name=tempfile_name, backend="gloo")
...@@ -593,15 +650,17 @@ def test_gradient_clipping(): ...@@ -593,15 +650,17 @@ def test_gradient_clipping():
def run_state_dict_distributed(rank, world_size, tempfile_name): def run_state_dict_distributed(rank, world_size, tempfile_name):
dist_init(rank, world_size, tempfile_name, backend="gloo") dist_init(rank, world_size, tempfile_name, backend="gloo")
device = torch.device(rank) device = torch.device(rank)
torch.manual_seed(rank) # make sure that the different rank get different data torch.manual_seed(rank) # make sure that the different rank get different data
# Run a dummy step so that the optimizer state dict exists # Setup two problems in parallel, we'll make sure that the second track (with save/load) follows the first one(untouched)
# We split the model in two to test the multiple param groups support
batch, input_width, hidden, target_width = 3, 20, 10, 5 batch, input_width, hidden, target_width = 3, 20, 10, 5
target = torch.rand((batch, target_width), device=device) target = torch.rand((batch, target_width), device=device)
inputs = torch.rand((batch, input_width), device=device) inputs = torch.rand((batch, input_width), device=device)
model_oss1 = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, hidden),).to(device) model_oss1 = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, hidden)).to(device)
head_oss1 = torch.nn.Linear(hidden, target_width).to(device) head_oss1 = torch.nn.Linear(hidden, target_width).to(device)
model_oss2 = copy.deepcopy(model_oss1) model_oss2 = copy.deepcopy(model_oss1)
...@@ -619,48 +678,36 @@ def run_state_dict_distributed(rank, world_size, tempfile_name): ...@@ -619,48 +678,36 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99) sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)
sharded_optimizer2.add_param_group({"params": head_oss2.parameters()}) sharded_optimizer2.add_param_group({"params": head_oss2.parameters()})
def run_grad_step(device, model, head, optimizer): loss_fn = torch.nn.L1Loss().to(device)
loss_fn = torch.nn.L1Loss()
loss_fn.to(device)
def run_grad_step(model, head, optimizer):
model.zero_grad() model.zero_grad()
outputs = head(model(inputs)) outputs = head(model(inputs))
loss = loss_fn(outputs, target) def check_equal_models(message: str):
loss.backward() for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
assert torch.allclose(param1, param2), message
optimizer.step() # pull the current state, broadcast it to all ranks
optimizer.zero_grad() sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks
state_dict2 = sharded_optimizer2.state_dict() if rank == RECIPIENT_RANK else {}
# save and reload without taking any steps state_dict2 = sync_object_ranks(state_dict2, RECIPIENT_RANK, device)
sharded_optimizer2.consolidate_state_dict()
state_dict2 = sharded_optimizer2.state_dict()
sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99) # re-create a new optimizer from scratch with absurd values, load the previous state
sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=1e6, momentum=0.0001)
sharded_optimizer2.add_param_group({"params": head_oss2.parameters()}) sharded_optimizer2.add_param_group({"params": head_oss2.parameters()})
sharded_optimizer2.load_state_dict(state_dict2) sharded_optimizer2.load_state_dict(state_dict2)
check_equal_models("parameters of the two identical models have diverged (before any steps)")
# now take a step and check that parameters are equal # now take a step and check that parameters are equal
# take a step run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1) run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2) check_equal_models("parameters of the two identical models have diverged (after stepping)")
# check that model parameters are equal
for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
assert torch.allclose(param1, param2), "parameters of the two identical models have diverged (before any steps)"
# take a step
run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2)
# check that model parameters are equal # save the state dict for one model only, then distribute to the other ranks
for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()): sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks
assert torch.allclose(param1, param2), "parameters of the two identical models have diverged (before saving)" state_dict2 = sharded_optimizer2.state_dict() if rank == RECIPIENT_RANK else {}
state_dict2 = sync_object_ranks(state_dict2, RECIPIENT_RANK, device)
# save the state dict for one model only
sharded_optimizer2.consolidate_state_dict()
state_dict2 = sharded_optimizer2.state_dict()
# Check that the pulled state and the .param_groups attribute are in sync # Check that the pulled state and the .param_groups attribute are in sync
for replica in range(len(state_dict2["param_groups"])): for replica in range(len(state_dict2["param_groups"])):
...@@ -669,18 +716,14 @@ def run_state_dict_distributed(rank, world_size, tempfile_name): ...@@ -669,18 +716,14 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
assert state_dict2["param_groups"][replica][k] == sharded_optimizer2.param_groups[0][k] assert state_dict2["param_groups"][replica][k] == sharded_optimizer2.param_groups[0][k]
# take a step # take a step
run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1) run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2) run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
check_equal_models("parameters of the two identical models have diverged (after consolidating)")
# check that saving did not cause a change in the parameters
for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
assert torch.allclose(
param1, param2
), "parameters of the two identical models have diverged (after consolidating)"
# save again # save again for one rank, then distribute to the others
sharded_optimizer2.consolidate_state_dict() sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks
state_dict2 = sharded_optimizer2.state_dict() state_dict2 = sharded_optimizer2.state_dict() if rank == RECIPIENT_RANK else {}
state_dict2 = sync_object_ranks(state_dict2, RECIPIENT_RANK, device)
# reload the state_dict # reload the state_dict
sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99) sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)
...@@ -688,23 +731,20 @@ def run_state_dict_distributed(rank, world_size, tempfile_name): ...@@ -688,23 +731,20 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
sharded_optimizer2.load_state_dict(state_dict2) sharded_optimizer2.load_state_dict(state_dict2)
# take a step # take a step
run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1) run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2) run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
check_equal_models("parameters of the two identical models have diverged (after reloading)")
# check that reloading a saved state dict does not change the parameters
for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
assert torch.allclose(param1, param2), "parameters of the two identical models have diverged (after reloading)"
dist.destroy_process_group() dist.destroy_process_group()
@skip_if_no_cuda @skip_if_no_cuda
def test_state_dict_distributed(): def test_state_dict_distributed():
world_size = 8 world_size = 2
temp_file_name = tempfile.mkstemp()[1] temp_file_name = tempfile.mkstemp()[1]
if torch.cuda.is_available(): if torch.cuda.is_available():
world_size = min(world_size, torch.cuda.device_count()) world_size = max(world_size, torch.cuda.device_count())
mp.spawn( mp.spawn(
run_state_dict_distributed, args=(world_size, temp_file_name), nprocs=world_size, join=True, run_state_dict_distributed, args=(world_size, temp_file_name), nprocs=world_size, join=True,
...@@ -719,19 +759,51 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name): ...@@ -719,19 +759,51 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
torch.cuda.set_device(rank) torch.cuda.set_device(rank)
torch.manual_seed(rank) torch.manual_seed(rank)
np.random.seed(rank) np.random.seed(rank)
hidden = 5
in_channels = 3
out_channels = 3
batch = 64
def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer]): def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer]):
# Any model works. Add one different buffer per rank # Any model works. Add one different buffer per rank
model = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 3), torch.nn.Linear(3, 3),) trunk = torch.nn.Sequential(torch.nn.Linear(in_channels, hidden), torch.nn.Linear(hidden, hidden))
model.register_buffer("test_buffer", torch.ones((1)) * rank) trunk.register_buffer("test_buffer", torch.ones((1)) * rank)
model.to(device) trunk.to(device)
head = torch.nn.Linear(hidden, out_channels).to(device)
# Define a model to be trained by OSS
oss_model = torch.nn.Sequential(trunk, head)
oss_trainable_params = [
{"params": trunk.parameters(), "lr": 1e-5},
{"params": head.parameters(), "lr": 1e-4},
]
optimizer_settings = {}
if isinstance(optim, torch.optim.SGD):
optimizer_settings["momentum"] = 0.9
sharded_optimizer = optim.OSS(
params=oss_trainable_params,
optim=optimizer,
group=None,
broadcast_buffer_size=2 ** 10,
**optimizer_settings,
)
oss_ddp_model = DDP(module=oss_model, device_ids=[rank], broadcast_buffers=True)
sharded_optimizer = optim.OSS(params=model.parameters(), optim=optimizer, lr=1e-3) # Define a model to be trained by normal pytorch + DDP
sharded_ddp_model = DDP(module=model, device_ids=[rank], broadcast_buffers=True) ddp_trunk = copy.deepcopy(trunk)
ddp_head = copy.deepcopy(head)
ddp_module = torch.nn.Sequential(ddp_trunk, ddp_head)
ddp_model_single = copy.deepcopy(model) ddp_trainable_params = [
ddp_optimizer = optimizer(ddp_model_single.parameters(), lr=1e-3) {"params": ddp_trunk.parameters(), "lr": 1e-5},
ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True) {"params": ddp_head.parameters(), "lr": 1e-4},
]
ddp_optimizer = optimizer(ddp_trainable_params, **optimizer_settings) # type: ignore
ddp_model = DDP(module=ddp_module, device_ids=[rank], broadcast_buffers=True)
def check_same_model_params(): def check_same_model_params():
for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups): for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
...@@ -740,17 +812,13 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name): ...@@ -740,17 +812,13 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
p, ddp_p, atol=1e-3 p, ddp_p, atol=1e-3
), f"Model parameters differ in between Pytorch optim and OSS \n{p} {ddp_p}\nworld size {world_size}" ), f"Model parameters differ in between Pytorch optim and OSS \n{p} {ddp_p}\nworld size {world_size}"
for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()): for b, ddp_b in zip(oss_ddp_model.buffers(), ddp_model.buffers()):
assert torch.allclose( assert torch.allclose(
b, ddp_b b, ddp_b
), f"Model buffers differ in between Pytorch optim and OSS\nworld size {world_size}" ), f"Model buffers differ in between Pytorch optim and OSS\nworld size {world_size}"
# The model should be synchronized in between the ranks at construction time, check that def check_step():
check_same_model_params() input_tensor = torch.rand((batch, in_channels)).to(device)
# The models should stay the same in between the ranks
for i in range(20):
input_tensor = torch.rand((64, 2)).to(device)
def closure_ddp(input_tensor=input_tensor): def closure_ddp(input_tensor=input_tensor):
ddp_optimizer.zero_grad() ddp_optimizer.zero_grad()
...@@ -760,7 +828,7 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name): ...@@ -760,7 +828,7 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
def closure_sharded(input_tensor=input_tensor): def closure_sharded(input_tensor=input_tensor):
sharded_optimizer.zero_grad() sharded_optimizer.zero_grad()
sharded_loss = sharded_ddp_model(input_tensor).abs().sum() sharded_loss = oss_ddp_model(input_tensor).abs().sum()
sharded_loss.backward() sharded_loss.backward()
return sharded_loss return sharded_loss
...@@ -771,8 +839,29 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name): ...@@ -771,8 +839,29 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
loss_ddp, loss_sharded_optim loss_ddp, loss_sharded_optim
), f"Losses differ in between Pytorch optim and OSS\nworld size {world_size}" ), f"Losses differ in between Pytorch optim and OSS\nworld size {world_size}"
# The model should be synchronized in between the ranks at construction time, check that
check_same_model_params()
# The models should stay the same in between ddp and sharded optimizer
for i in range(5):
check_step()
check_same_model_params() check_same_model_params()
# Check that the checkpoints are compatible
# - get states
ddp_state_dict = ddp_optimizer.state_dict()
sharded_optimizer.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)
sharded_optim_state_dict = sharded_optimizer.state_dict() if rank == RECIPIENT_RANK else {}
sharded_optim_state_dict = sync_object_ranks(sharded_optim_state_dict, RECIPIENT_RANK, device)
# - cross load the states
ddp_optimizer.load_state_dict(sharded_optim_state_dict) # mixup on purpose !
sharded_optimizer.load_state_dict(ddp_state_dict)
# - run one step and check that the models are still the same
check_step()
check_same_model_params()
for opt in [torch.optim.SGD, torch.optim.Adam]: for opt in [torch.optim.SGD, torch.optim.Adam]:
check_optimizer_equivalence(opt) check_optimizer_equivalence(opt)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment