Unverified Commit 9e8929e6 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[feat][OSS] elastic and pytorch compatible checkpoints (#310)

* adding a test to prove the inter operability with upstream pytorch
* updating the changelog
* eager state pruning
* pytorch 1.5 compat
parent c2dd6c34
......@@ -25,3 +25,4 @@ venv/
ENV/
env.bak/
venv.bak/
.vscode/*
......@@ -6,8 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [next rel] - TBD
### Added
- Pytorch compatibility for OSS checkpoints (#310)
- Elastic checkpoints for OSS, world size can vary in between save and loads (#310)
- Tensor views for OSS bucketing, reduced CPU use (#300)
- Bucket calls in ShardedDDP, for faster inter node communications (#327)
- Tensor views for OSS bucketing, reduced CPU use
## [0.1.4] - 2021-01-07
### Fixed
......
......@@ -5,11 +5,10 @@
from collections import OrderedDict, deque
import copy
import itertools
from itertools import chain
import logging
from math import inf
from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Type, Union
import torch
import torch.distributed as dist
......@@ -80,6 +79,8 @@ class OSS(Optimizer):
self._per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict() # device, rank, params
self._param_rank: Dict[torch.Tensor, int] = {}
self._partition_parameters: List[List[dict]] = []
self._index_to_param: Dict[int, torch.Tensor] = {}
self._param_to_index: Dict[int, int] = {}
# Build the wrapped optimizer, responsible for a shard of the params
self.group = group if group is not None else dist.group.WORLD
......@@ -142,6 +143,24 @@ class OSS(Optimizer):
return self._partition_parameters
@property
def index_to_param(self) -> Dict[int, torch.Tensor]:
""" Hash table in between parameter indices in the global optimizer scheme, and the actual params
"""
if len(self._index_to_param) == 0:
self._index_to_param = {i: p for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))}
return self._index_to_param
@property
def param_to_index(self) -> Dict[int, int]:
""" Hash table in between parameter indices in the global optimizer scheme, and the actual params
"""
if len(self._param_to_index) == 0:
self._param_to_index = {id(p): i for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))}
return self._param_to_index
@property
def per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]:
"""Sorted list of all the params, first per device then per rank.
......@@ -191,7 +210,7 @@ class OSS(Optimizer):
.. note: Any extra parameter is passed to the base optimizer as-is"""
# Sync oss param_groups attributes in case they've been updated by a scheduler.
self._sync_param_groups()
OSS._sync_param_groups(self.param_groups, self.optim.param_groups)
# Run the optimizer step on this shard only:
if closure is not None:
......@@ -203,7 +222,7 @@ class OSS(Optimizer):
self._broadcast_params()
# Sync hypothethical new results from the wrapped optimizer to the exposed param_groups
self._sync_param_groups(local_to_global=True)
OSS._sync_param_groups(self.optim.param_groups, self.param_groups)
return loss
......@@ -237,7 +256,7 @@ class OSS(Optimizer):
norm_type = float(norm_type)
# Filter out the grad-less params, concatenate params from all devices
local_params = itertools.chain(
local_params = chain(
*[
list(filter(lambda x: x.grad is not None, device_params[self.rank]))
for device_params in self.per_device_params.values()
......@@ -280,26 +299,13 @@ class OSS(Optimizer):
return total_norm
# State dict interfaces
def local_state_dict(self) -> dict:
"""Gets this rank's state_dict.
Returns:
The state of the optimizer as a :class:`dict`.
It contains two entries:
* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a dict containing all parameter groups
"""
return self.optim.state_dict()
def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
"""Update the consolidated state_dict list, one per rank.
.. warning: This needs to be called on all replicas"""
# Sync lr and other attributes in case its been updated
self._sync_param_groups()
OSS._sync_param_groups(self.param_groups, self.optim.param_groups)
if self.rank == recipient_rank:
# Pull the sharded state from all the other replicas
......@@ -310,12 +316,104 @@ class OSS(Optimizer):
# Acknowledge broadcasts, and send this rank's shard when needed
self._broadcast_state_dict()
def local_state_dict(self) -> dict:
""" .. deprecated:: 0.1.5
Returns this rank's state_dict as a :class:`dict` which contains two entries:
* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a dict containing all parameter groups
.. warning: This does not represent the optimizer state dict, only a shard.
"""
return self.optim.state_dict()
def state_dict(self) -> Dict[str, Any]:
"""Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the
sharded properties are not exposed. It contains two entries:
* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a dict containing all parameter groups
.. warning:
If the state has not been consolidated, this returns a shard's worth, not the global state.
.. warning:
Returning the global state is limited to the replica which was responsible for the consolidation.
The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
"""
if len(self._all_states) == 0:
raise RuntimeError(
"Optimizer state has not been consolidated on this rank. \
Please call `consolidate_state_dict()` on all ranks beforehand if you meant to save the global state"
)
# Unify the shard states and the state that pytorch would expect, given the model.
# Indexation needs several redirections, since each shard only knows a limited scope of the model
# - get the pytorch compliant parameter indexing
state_dict = super().state_dict()
# - go through the per-shard states, which are all indexed locally
for rank, s in enumerate(self._all_states):
# -- match the local indexing and the global partition, update the corresponding saved state globally
for local_pg, global_pg in zip(s["param_groups"], self.partition_parameters()[rank]):
local_index_to_param_id = {
i_param: id(global_pg["params"][i]) for i, i_param in enumerate(local_pg["params"])
}
for local_param_index in local_pg["params"]:
# Update the state, if any
if local_param_index in s["state"].keys():
global_id = self.param_to_index[local_index_to_param_id[local_param_index]]
state_dict["state"][global_id] = s["state"][local_param_index]
return state_dict
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""Restore the global parameter groups as well as the shard.
Arguments:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`
"""
# NOTE: PyTorch 1.5 does not index linearly but with the id(params) at saving time
# we work around that here by using the fact that the params are ordered as in the param_groups
for i_param, (key, value) in enumerate(state_dict["state"].items()):
param = self.index_to_param[i_param]
# Populate the sharded optimizer state on the fly
if self.param_to_rank[param] != self.rank:
state_dict["state"][key] = None
if key in self.index_to_param:
param = self.index_to_param[i_param]
# Only add this state to the sharded optimizer if it owns this param
for pg in self.optim.param_groups:
if id(param) in [id(p) for p in pg["params"]]:
self.optim.state[param] = recursive_copy_to_device(
value, non_blocking=True, device=param.device
)
super().load_state_dict(state_dict)
# Sync with the optimizer param groups
OSS._sync_param_groups(state_dict["param_groups"], self.param_groups)
OSS._sync_param_groups(self.param_groups, self.optim.param_groups)
def _broadcast_state_dict(self) -> None:
"""Broadcast this rank's state shard, discard others"""
# Default to CPU space to gain some memory headroom
local_cpu_state = recursive_copy_to_device(
self.local_state_dict(), non_blocking=True, device=torch.device("cpu")
self.optim.state_dict(), non_blocking=True, device=torch.device("cpu")
)
# Tensor cannot be really empty, even if its size is meaningless
......@@ -350,7 +448,7 @@ class OSS(Optimizer):
if rank == self.rank:
logging.debug("Saving self state")
all_states.append(
recursive_copy_to_device(self.local_state_dict(), non_blocking=True, device=torch.device("cpu"))
recursive_copy_to_device(self.optim.state_dict(), non_blocking=True, device=torch.device("cpu"))
)
# Sync with other replicas
......@@ -378,103 +476,6 @@ class OSS(Optimizer):
return all_states
def state_dict(self) -> Dict[str, Any]:
"""Return the last known global optimizer state, which consist of a list of the shards.
.. warning:
If the state has not been consolidated, this returns a shard's worth, not the global state.
.. warning:
Returning the global state is limited to the replica which was responsible for the consolidation.
The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
"""
if len(self._all_states) == 0:
logging.warning("Optimizer state has not been consolidated. Returning the local state")
logging.warning("Please call `consolidate_state_dict()` beforehand if you meant to save the global state")
state_dict = self.local_state_dict()
state_dict["local_state_dict"] = True
return state_dict
# Flatten the param_groups, save the partition which logs the rank <> shard correspondence
partition: List[Tuple[int, int]] = []
param_groups: List[Dict[Any, Any]] = []
start = 0
for i, s in enumerate(self._all_states):
param_groups.extend(s["param_groups"])
end = start + len(s["param_groups"])
partition.append((start, end))
start = end
return {
"state": [s["state"] for s in self._all_states],
"param_groups": param_groups,
"partition": partition,
"local_state_dict": False,
}
@staticmethod
def rank_local_state_dict(rank: int, state_dict: dict) -> dict:
"""Returns the local_state_dict for a given rank.
Arguments:
rank (int): rank to get local_state_dict for
state_dict (dict): global state_dict
"""
param_groups = state_dict["param_groups"][state_dict["partition"][rank][0] : state_dict["partition"][rank][1]]
return {"state": state_dict["state"][rank], "param_groups": param_groups}
def load_local_state_dict(self, state_dict: dict) -> None:
"""Loads this rank's state_dict.
.. warning: This is not meant to load the global state dict.
"""
self.optim.load_state_dict(state_dict)
# Workaround PyTorch bug that casts state (https://github.com/pytorch/pytorch/issues/43706)
# Copied from https://github.com/pytorch/fairseq/blob/v0.9.0/fairseq/optim/fp16_optimizer.py#L251-L268
groups = self.optim.param_groups
saved_groups = state_dict["param_groups"]
id_map = {
old_id: p
for old_id, p in zip(chain(*(g["params"] for g in saved_groups)), chain(*(g["params"] for g in groups)))
}
for k, v in state_dict["state"].items():
if k in id_map:
param = id_map[k]
self.optim.state[param] = recursive_copy_to_device(v, non_blocking=True, device=param.device)
# Restore the global param_groups (the params themselves are already correct)
for global_group, local_group in zip(self.param_groups, groups):
for k, v in local_group.items():
if k != "params":
global_group[k] = v
# Force a re-partitioning, in case the model changed with the new state
self._partition_parameters.clear()
self._per_device_params.clear()
self._param_rank.clear()
# Update the bucketing strategy accordingly
self._setup_bucket_strategy()
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""Restore the global parameter groups as well as the shard.
Arguments:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`
"""
# Check whether we got a local or global dict
if "local_state_dict" in state_dict and state_dict["local_state_dict"]:
self.load_local_state_dict(state_dict)
else:
# Dispatch this rank's state dictionary to the wrapped shard optimizer
self.load_local_state_dict(OSS.rank_local_state_dict(self.rank, state_dict))
def add_param_group(self, param_group: dict) -> None:
"""Add a param group to the :class:`Optimizer` s `param_groups`.
......@@ -491,10 +492,9 @@ class OSS(Optimizer):
super().add_param_group(param_group)
if not self.in_super_constructor:
# Force a re-partitioning
self._partition_parameters.clear()
self._per_device_params.clear()
self._param_rank.clear()
self._clear_cache()
# Update the partition
param_groups = self.partition_parameters()[self.rank]
if len(param_groups) == len(self.optim.param_groups) + 1:
self.optim.add_param_group(param_groups[-1])
......@@ -502,6 +502,13 @@ class OSS(Optimizer):
# Update the bucketing strategy accordingly
self._setup_bucket_strategy()
def _clear_cache(self) -> None:
self._partition_parameters.clear()
self._per_device_params.clear()
self._param_rank.clear()
self._index_to_param.clear()
self._param_to_index.clear()
@staticmethod
def get_global_rank(group: Any, rank: int) -> int:
if group is dist.group.WORLD:
......@@ -510,20 +517,14 @@ class OSS(Optimizer):
global_rank = dist.distributed_c10d._get_global_rank(group, rank)
return global_rank
@torch.no_grad()
def _sync_param_groups(self, local_to_global: bool = False) -> None:
"""Sync learning rate and other optimizer attributes (needed to support schedulers).
If the global param groups have been altered, and we want to make sure that the
wrapped optimizer uses the up to date version.
Conversely if the wrapped optimizer has new keys, we expose them through the global param groups"""
@staticmethod
def _sync_param_groups(source: List[Dict[Any, Any]], destination: List[Dict[Any, Any]]) -> None:
"""Sync learning rate and other optimizer attributes (needed to support schedulers)."""
for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
for source_group, destination_group in zip(source, destination):
# Sync everything but the parameters
for k in filter(lambda x: x != "params", local_group.keys()):
if local_to_global:
global_group[k] = local_group[k]
elif k in global_group.keys():
local_group[k] = global_group[k]
for k in filter(lambda x: x != "params", source_group.keys()):
destination_group[k] = source_group[k]
@torch.no_grad()
def _broadcast_params(self) -> None:
......@@ -614,7 +615,6 @@ class OSS(Optimizer):
self.buckets[device][dst_rank][offset:offset_next].copy_(param.data.flatten())
param.data = self.buckets[device][dst_rank][offset:offset_next].view_as(param.data)
offset = offset_next
else:
self.should_bucket_param.append(False)
......
......@@ -11,7 +11,7 @@
import copy
from math import inf
import tempfile
from typing import Type, cast
from typing import Any, Type, cast
import unittest
import numpy as np
......@@ -26,6 +26,7 @@ from fairscale.utils.testing import skip_if_no_cuda, skip_if_single_gpu
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore
DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")
RECIPIENT_RANK = 1
try:
from torch.distributed import broadcast_object_list # noqa
......@@ -42,6 +43,19 @@ def dist_init(rank, world_size, tempfile_name, backend=BACKEND):
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
def sync_object_ranks(something_to_sync: Any, reference_rank: int, device: torch.device) -> Any:
if _torch_broadcast_object:
package = [something_to_sync]
dist.broadcast_object_list(package, src=reference_rank, group=dist.group.WORLD)
package_sync = package[0]
else:
package_sync = optim.utils.broadcast_object(
something_to_sync, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device
)
return package_sync
class TestSingleRank(unittest.TestCase):
"""
All the following tests do not check for inter-process communication
......@@ -158,31 +172,11 @@ class TestSingleRank(unittest.TestCase):
o.step()
assert x == torch.tensor([0.9], device=DEVICE)
def test_local_state_dict(self):
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.1)
local_state_dict = o.local_state_dict()
o = optim.OSS([x], lr=0.01)
o.load_local_state_dict(local_state_dict)
# We should now be using a lr of 0.1.
assert o.optim.param_groups[0]["lr"] == 0.1
assert o.param_groups[0]["lr"] == 0.1
x.backward()
o.step()
assert x == torch.tensor([0.9], device=DEVICE)
def test_implicit_local_state_dict(self):
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.1)
local_state_dict = o.state_dict()
o = optim.OSS([x], lr=0.01)
o.load_state_dict(local_state_dict)
# We should now be using a lr of 0.1.
assert o.optim.param_groups[0]["lr"] == 0.1
assert o.param_groups[0]["lr"] == 0.1
x.backward()
o.step()
assert x == torch.tensor([0.9], device=DEVICE)
with pytest.raises(RuntimeError):
_ = o.state_dict()
def run_test_add_param_group(rank, world_size, tempfile_name):
......@@ -348,7 +342,10 @@ def test_step_with_closure():
def run_test_sharding(rank, world_size, tempfile_name):
dist_init(rank, world_size, tempfile_name)
params = []
for size in [5, 4, 2, 6, 4, 3]:
sizes = [9, 7, 5, 3]
sizes_world = sizes * world_size
for size in sizes_world:
params.append(torch.rand(size, 1))
# Make sure that the params are trainable, enforces size-based partitioning
......@@ -356,17 +353,17 @@ def run_test_sharding(rank, world_size, tempfile_name):
p.requires_grad = True
o = optim.OSS(params, lr=0.1)
assert sum([x.numel() for x in o.optim.param_groups[0]["params"]]) == 8
assert sum([x.numel() for x in o.optim.param_groups[0]["params"]]) == sum(sizes)
dist.destroy_process_group()
def test_sharding():
world_size = 3
if not torch.cuda.is_available() or torch.cuda.device_count() < world_size:
pytest.skip("Not enough GPUs for NCCL-based test")
temp_file_name = tempfile.mkstemp()[1]
world_size = 4
if torch.cuda.is_available():
world_size = min(world_size, torch.cuda.device_count())
_, temp_file_name = tempfile.mkstemp()
mp.spawn(run_test_sharding, args=(world_size, temp_file_name), nprocs=world_size, join=True)
......@@ -405,18 +402,12 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
# - load it again
if rank == reference_rank:
optimizer_state_dict = optimizer.state_dict()
assert len(optimizer_state_dict["state"]) == world_size
assert len(optimizer_state_dict["state"]) == len(list(model.parameters()))
else:
optimizer_state_dict = {}
optim_state = [optimizer_state_dict]
if _torch_broadcast_object:
dist.broadcast_object_list(optim_state, src=reference_rank, group=dist.group.WORLD)
optimizer_state_dict = optim_state[0]
else:
optimizer_state_dict = optim.utils.broadcast_object(
optimizer_state_dict, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device
)
# distribute to the other ranks
optimizer_state_dict = sync_object_ranks(optimizer_state_dict, reference_rank, device)
# Load the optimizer state dict
optimizer.load_state_dict(optimizer_state_dict)
......@@ -436,6 +427,72 @@ def test_collect_shards():
)
def run_test_reproducibility(rank, world_size, reference_rank, tempfile_name):
dist_init(rank, world_size, tempfile_name)
device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE
# Run a dummy step so that the optimizer state dict exists
batch, input_width, hidden, target_width = 3, 3, 3, 5
target = torch.rand((batch, target_width), device=device)
inputs = torch.rand((batch, input_width), device=device)
model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width))
model.to(device)
loss_fn = torch.nn.L1Loss()
loss_fn.to(device)
optimizer = optim.OSS(model.parameters(), optim=torch.optim.RMSprop, lr=0.1)
def closure():
optimizer.zero_grad()
output = model(inputs)
loss = loss_fn(output, target)
loss.backward()
return loss
_ = optimizer.step(closure=closure)
# Update the optimizer state on the reference rank
optimizer.consolidate_state_dict(recipient_rank=reference_rank)
# Fetch the state on the reference rank, broadcast to the other ones
if rank == reference_rank:
optimizer_state_dict = optimizer.state_dict()
else:
optimizer_state_dict = {}
# Run two steps, log the loss
_ = optimizer.step(closure=closure)
reference_loss = optimizer.step(closure=closure)
# Load the optimizer state dict, rewind the state two steps back
optimizer.load_state_dict(optimizer_state_dict)
# Run two new steps, log the loss again and check that we get the same
_ = optimizer.step(closure=closure)
test_loss = optimizer.step(closure=closure)
assert torch.allclose(reference_loss, test_loss)
dist.destroy_process_group()
def test_reproducibility():
world_size = 2
temp_file_name = tempfile.mkstemp()[1]
if torch.cuda.is_available() and torch.cuda.device_count() < world_size:
# Bail out if not enough devices
return
reference_rank = 0
mp.spawn(
run_test_collect_shards, args=(world_size, reference_rank, temp_file_name), nprocs=world_size, join=True,
)
def run_test_multiple_groups(rank, world_size, tempfile_name):
# Only work with the even ranks, to check that the global_rank indexing is properly used
dist_init(rank=rank, world_size=world_size, tempfile_name=tempfile_name, backend="gloo")
......@@ -593,15 +650,17 @@ def test_gradient_clipping():
def run_state_dict_distributed(rank, world_size, tempfile_name):
dist_init(rank, world_size, tempfile_name, backend="gloo")
device = torch.device(rank)
torch.manual_seed(rank) # make sure that the different rank get different data
# Run a dummy step so that the optimizer state dict exists
# Setup two problems in parallel, we'll make sure that the second track (with save/load) follows the first one(untouched)
# We split the model in two to test the multiple param groups support
batch, input_width, hidden, target_width = 3, 20, 10, 5
target = torch.rand((batch, target_width), device=device)
inputs = torch.rand((batch, input_width), device=device)
model_oss1 = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, hidden),).to(device)
model_oss1 = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, hidden)).to(device)
head_oss1 = torch.nn.Linear(hidden, target_width).to(device)
model_oss2 = copy.deepcopy(model_oss1)
......@@ -619,48 +678,36 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)
sharded_optimizer2.add_param_group({"params": head_oss2.parameters()})
def run_grad_step(device, model, head, optimizer):
loss_fn = torch.nn.L1Loss()
loss_fn.to(device)
loss_fn = torch.nn.L1Loss().to(device)
def run_grad_step(model, head, optimizer):
model.zero_grad()
outputs = head(model(inputs))
loss = loss_fn(outputs, target)
loss.backward()
def check_equal_models(message: str):
for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
assert torch.allclose(param1, param2), message
optimizer.step()
optimizer.zero_grad()
# save and reload without taking any steps
sharded_optimizer2.consolidate_state_dict()
state_dict2 = sharded_optimizer2.state_dict()
# pull the current state, broadcast it to all ranks
sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks
state_dict2 = sharded_optimizer2.state_dict() if rank == RECIPIENT_RANK else {}
state_dict2 = sync_object_ranks(state_dict2, RECIPIENT_RANK, device)
sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)
# re-create a new optimizer from scratch with absurd values, load the previous state
sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=1e6, momentum=0.0001)
sharded_optimizer2.add_param_group({"params": head_oss2.parameters()})
sharded_optimizer2.load_state_dict(state_dict2)
check_equal_models("parameters of the two identical models have diverged (before any steps)")
# now take a step and check that parameters are equal
# take a step
run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2)
# check that model parameters are equal
for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
assert torch.allclose(param1, param2), "parameters of the two identical models have diverged (before any steps)"
# take a step
run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2)
run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
check_equal_models("parameters of the two identical models have diverged (after stepping)")
# check that model parameters are equal
for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
assert torch.allclose(param1, param2), "parameters of the two identical models have diverged (before saving)"
# save the state dict for one model only
sharded_optimizer2.consolidate_state_dict()
state_dict2 = sharded_optimizer2.state_dict()
# save the state dict for one model only, then distribute to the other ranks
sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks
state_dict2 = sharded_optimizer2.state_dict() if rank == RECIPIENT_RANK else {}
state_dict2 = sync_object_ranks(state_dict2, RECIPIENT_RANK, device)
# Check that the pulled state and the .param_groups attribute are in sync
for replica in range(len(state_dict2["param_groups"])):
......@@ -669,18 +716,14 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
assert state_dict2["param_groups"][replica][k] == sharded_optimizer2.param_groups[0][k]
# take a step
run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2)
# check that saving did not cause a change in the parameters
for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
assert torch.allclose(
param1, param2
), "parameters of the two identical models have diverged (after consolidating)"
run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
check_equal_models("parameters of the two identical models have diverged (after consolidating)")
# save again
sharded_optimizer2.consolidate_state_dict()
state_dict2 = sharded_optimizer2.state_dict()
# save again for one rank, then distribute to the others
sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks
state_dict2 = sharded_optimizer2.state_dict() if rank == RECIPIENT_RANK else {}
state_dict2 = sync_object_ranks(state_dict2, RECIPIENT_RANK, device)
# reload the state_dict
sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)
......@@ -688,23 +731,20 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
sharded_optimizer2.load_state_dict(state_dict2)
# take a step
run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2)
# check that reloading a saved state dict does not change the parameters
for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
assert torch.allclose(param1, param2), "parameters of the two identical models have diverged (after reloading)"
run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
check_equal_models("parameters of the two identical models have diverged (after reloading)")
dist.destroy_process_group()
@skip_if_no_cuda
def test_state_dict_distributed():
world_size = 8
world_size = 2
temp_file_name = tempfile.mkstemp()[1]
if torch.cuda.is_available():
world_size = min(world_size, torch.cuda.device_count())
world_size = max(world_size, torch.cuda.device_count())
mp.spawn(
run_state_dict_distributed, args=(world_size, temp_file_name), nprocs=world_size, join=True,
......@@ -719,19 +759,51 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank)
hidden = 5
in_channels = 3
out_channels = 3
batch = 64
def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer]):
# Any model works. Add one different buffer per rank
model = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 3), torch.nn.Linear(3, 3),)
model.register_buffer("test_buffer", torch.ones((1)) * rank)
model.to(device)
trunk = torch.nn.Sequential(torch.nn.Linear(in_channels, hidden), torch.nn.Linear(hidden, hidden))
trunk.register_buffer("test_buffer", torch.ones((1)) * rank)
trunk.to(device)
head = torch.nn.Linear(hidden, out_channels).to(device)
# Define a model to be trained by OSS
oss_model = torch.nn.Sequential(trunk, head)
oss_trainable_params = [
{"params": trunk.parameters(), "lr": 1e-5},
{"params": head.parameters(), "lr": 1e-4},
]
optimizer_settings = {}
if isinstance(optim, torch.optim.SGD):
optimizer_settings["momentum"] = 0.9
sharded_optimizer = optim.OSS(
params=oss_trainable_params,
optim=optimizer,
group=None,
broadcast_buffer_size=2 ** 10,
**optimizer_settings,
)
oss_ddp_model = DDP(module=oss_model, device_ids=[rank], broadcast_buffers=True)
sharded_optimizer = optim.OSS(params=model.parameters(), optim=optimizer, lr=1e-3)
sharded_ddp_model = DDP(module=model, device_ids=[rank], broadcast_buffers=True)
# Define a model to be trained by normal pytorch + DDP
ddp_trunk = copy.deepcopy(trunk)
ddp_head = copy.deepcopy(head)
ddp_module = torch.nn.Sequential(ddp_trunk, ddp_head)
ddp_model_single = copy.deepcopy(model)
ddp_optimizer = optimizer(ddp_model_single.parameters(), lr=1e-3)
ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)
ddp_trainable_params = [
{"params": ddp_trunk.parameters(), "lr": 1e-5},
{"params": ddp_head.parameters(), "lr": 1e-4},
]
ddp_optimizer = optimizer(ddp_trainable_params, **optimizer_settings) # type: ignore
ddp_model = DDP(module=ddp_module, device_ids=[rank], broadcast_buffers=True)
def check_same_model_params():
for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
......@@ -740,17 +812,13 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
p, ddp_p, atol=1e-3
), f"Model parameters differ in between Pytorch optim and OSS \n{p} {ddp_p}\nworld size {world_size}"
for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
for b, ddp_b in zip(oss_ddp_model.buffers(), ddp_model.buffers()):
assert torch.allclose(
b, ddp_b
), f"Model buffers differ in between Pytorch optim and OSS\nworld size {world_size}"
# The model should be synchronized in between the ranks at construction time, check that
check_same_model_params()
# The models should stay the same in between the ranks
for i in range(20):
input_tensor = torch.rand((64, 2)).to(device)
def check_step():
input_tensor = torch.rand((batch, in_channels)).to(device)
def closure_ddp(input_tensor=input_tensor):
ddp_optimizer.zero_grad()
......@@ -760,7 +828,7 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
def closure_sharded(input_tensor=input_tensor):
sharded_optimizer.zero_grad()
sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
sharded_loss = oss_ddp_model(input_tensor).abs().sum()
sharded_loss.backward()
return sharded_loss
......@@ -771,8 +839,29 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
loss_ddp, loss_sharded_optim
), f"Losses differ in between Pytorch optim and OSS\nworld size {world_size}"
# The model should be synchronized in between the ranks at construction time, check that
check_same_model_params()
# The models should stay the same in between ddp and sharded optimizer
for i in range(5):
check_step()
check_same_model_params()
# Check that the checkpoints are compatible
# - get states
ddp_state_dict = ddp_optimizer.state_dict()
sharded_optimizer.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)
sharded_optim_state_dict = sharded_optimizer.state_dict() if rank == RECIPIENT_RANK else {}
sharded_optim_state_dict = sync_object_ranks(sharded_optim_state_dict, RECIPIENT_RANK, device)
# - cross load the states
ddp_optimizer.load_state_dict(sharded_optim_state_dict) # mixup on purpose !
sharded_optimizer.load_state_dict(ddp_state_dict)
# - run one step and check that the models are still the same
check_step()
check_same_model_params()
for opt in [torch.optim.SGD, torch.optim.Adam]:
check_optimizer_equivalence(opt)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment