Unverified Commit 2d2412e2 authored by Benjamin Lefaudeux, committed by GitHub

[feat] Make OSS state available on all ranks (#500)

* Extend the current state_dict interface so that consolidation and retrieval can be done in a single call, and so that the optimizer state can be checkpointed on all ranks
parent 8dc2030b
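In practice, the new interface can be exercised roughly as follows. This is a minimal sketch rather than code from this commit: it assumes torch.distributed is already initialized, `model` is an existing module, and the checkpoint path is illustrative.

import torch
import torch.distributed as dist
from fairscale.optim import OSS

# Wrap a base optimizer with OSS so that its state is sharded across ranks.
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1, momentum=0.9)

# ... run some training steps ...

# New in this change: materialize the full optimizer state on every rank
# with a single collective call. This must be invoked on all ranks.
full_state = optimizer.state_dict(all_ranks=True)

# Only one rank needs to write the checkpoint to disk.
if dist.get_rank() == 0:
    torch.save(full_state, "oss_checkpoint.pt")  # illustrative path

The diff below implements this by generalizing consolidate_state_dict and adding the all_ranks flag to state_dict.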
@@ -297,19 +297,55 @@ class OSS(Optimizer):
def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
"""Update the consolidated state_dict list, one per rank.
Arguments:
recipient_rank (int): on which rank to materialize the full state dict.
-1 is a special value, which means that all ranks should have the state.
.. warning: This needs to be called on all replicas"""
# Sync lr and other attributes in case it's been updated
OSS._sync_param_groups(self.param_groups, self.optim.param_groups)
if self.rank == recipient_rank:
# Pull the sharded state from all the other replicas
# Store all the states in order, rank by rank
logging.debug("Pulling the sharded optimizer state from all replicas")
self._all_states = self._collect_sharded_states()
self._all_states = []
should_collect_state = self.rank == recipient_rank or recipient_rank == -1
should_send_state = (self.rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1
for rank in range(self.world_size):
if rank == self.rank:
if should_collect_state:
logging.debug("Saving self state")
self._all_states.append(
recursive_copy_to_device(self.optim.state_dict(), non_blocking=True, device=torch.device("cpu"))
)
# Sync with other replicas
state_to_share = (
self.optim.state_dict()
if should_send_state
else torch.tensor([0], dtype=torch.uint8, device=self._default_device)
)
broadcast_object(
state_to_share, src_rank=self.global_rank, group=self.group, dist_device=self._default_device,
)
else:
# Acknowledge broadcasts, and send this rank's shard when needed
self._broadcast_state_dict()
# Fetch the optim state from the other replicas
replica_state = broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=self._default_device),
src_rank=self._local_to_global_rank[rank],
group=self.group,
dist_device=self._default_device,
)
if should_collect_state:
self._all_states.append(
recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
)
logging.debug("State from rank %s received", rank)
def local_state_dict(self) -> dict:
""" .. deprecated:: 0.1.5
@@ -325,29 +361,38 @@ class OSS(Optimizer):
"""
return self.optim.state_dict()
def state_dict(self) -> Dict[str, Any]:
def state_dict(self, all_ranks: bool = False) -> Dict[str, Any]:
"""Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the
sharded properties are not exposed. It contains two entries:
sharded properties are not exposed.
Arguments:
all_ranks (bool): materialize the state on all ranks. In that case, `.state_dict()` needs to be called on
all ranks
Returns:
a dict with two entries
* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a dict containing all parameter groups
.. warning:
If the state has not been consolidated, this returns a shard's worth, not the global state.
.. warning:
Returning the global state is limited to the replica which was responsible for the consolidation.
The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
Returning the global state is limited to the replica which was responsible for the consolidation,
if `all_ranks` was not set to `True`. In that case, the state may also not be up to date,
depending on when `consolidate_state_dict` was last called.
"""
if len(self._all_states) == 0:
if not all_ranks and len(self._all_states) == 0:
raise RuntimeError(
"Optimizer state has not been consolidated on this rank. \
Please call `consolidate_state_dict()` on all ranks beforehand if you meant to save the global state"
)
if all_ranks:
# Consolidate the state on every rank
self.consolidate_state_dict(recipient_rank=-1)
# Unify the shard states and the state that pytorch would expect, given the model.
# Indexation needs several redirections, since each shard only knows a limited scope of the model
# - get the pytorch compliant parameter indexing
@@ -406,7 +451,6 @@ class OSS(Optimizer):
self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
else:
# Not a param, copied as-is (backward compatibility or exotic optimizers)
print(key, "not in idmap")
param = _param_list[key]
self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
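Restoring from a consolidated state dict is symmetric. A hedged usage sketch, not part of this diff, with an illustrative checkpoint path and an already-initialized process group:

# Every rank loads the same consolidated state dict; OSS.load_state_dict
# then maps the entries back onto the locally owned parameter shard.
full_state = torch.load("oss_checkpoint.pt", map_location="cpu")  # illustrative path
optimizer.load_state_dict(full_state)

With all_ranks available, the dedicated _broadcast_state_dict and _collect_sharded_states helpers are no longer needed and are removed in the hunks below.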
@@ -430,69 +474,6 @@ class OSS(Optimizer):
self._setup_flat_buffers()
def _broadcast_state_dict(self) -> None:
"""Broadcast this rank's state shard, discard others"""
# Tensor cannot be really empty, even if its size is meaningless
dummy_sync_tensor = torch.tensor([1], device=self._default_device)
for rank in range(self.world_size):
if rank == self.rank:
# Send the state to the reference replica
logging.debug(
"Sending the sharded optimizer state to the reference replica from rank %s", rank,
)
# legacy compatibility for old torch versions
broadcast_object(
self.local_state_dict(),
src_rank=self.global_rank,
group=self.group,
dist_device=self._default_device,
)
else:
# Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
broadcast_object(
torch.tensor([dummy_sync_tensor], dtype=torch.uint8, device=self._default_device),
src_rank=self._local_to_global_rank[rank],
group=self.group,
dist_device=self._default_device,
)
def _collect_sharded_states(self) -> List[Dict[str, Any]]:
"""Collect all the state shards, in CPU memory."""
all_states = []
for rank in range(self.world_size):
if rank == self.rank:
logging.debug("Saving self state")
all_states.append(
recursive_copy_to_device(self.optim.state_dict(), non_blocking=True, device=torch.device("cpu"))
)
# Sync with other replicas
broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=self._default_device),
src_rank=self.global_rank,
group=self.group,
dist_device=self._default_device,
)
else:
# Fetch the optim state from the other replicas
replica_state = broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=self._default_device),
src_rank=self._local_to_global_rank[rank],
group=self.group,
dist_device=self._default_device,
)
all_states.append(
recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
)
logging.debug("State from rank %s received", rank)
return all_states
def add_param_group(self, param_group: dict) -> None:
"""Add a param group to the :class:`Optimizer` s `param_groups`.
......
@@ -24,6 +24,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import fairscale.optim as optim
from fairscale.utils.testing import (
check_same_model_params,
check_same_models_across_ranks,
skip_if_no_cuda,
skip_if_py39_no_cuda,
skip_if_single_gpu,
@@ -448,6 +449,14 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
for state in optimizer.state.values():
for _, _ in state.items():
pass
# Test the state dict materialization on all ranks
_ = optimizer.step(closure=closure)
optimizer_state_dict = optimizer.state_dict(all_ranks=True) # one per rank
optimizer.load_state_dict(optimizer_state_dict)
_ = optimizer.step(closure=closure)
check_same_models_across_ranks(model, dist.group.WORLD, params_should_be_equal=True, check_broadcast_buffers=False)
dist.destroy_process_group()
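For reference, the cross-rank parity check used above can be approximated by hand. This is a hedged sketch, not the fairscale testing helper:

# Broadcast rank 0's parameters and compare them with the local copies.
for p in model.parameters():
    reference = p.detach().clone()
    dist.broadcast(reference, src=0)
    assert torch.allclose(p.detach(), reference), "Parameters diverged across ranks"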
@@ -477,6 +486,7 @@ def run_test_reproducibility(rank, world_size, reference_rank, tempfile_name):
model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width))
model.to(device)
model = DDP(model, device_ids=[device])
loss_fn = torch.nn.L1Loss()
loss_fn.to(device)
@@ -492,14 +502,9 @@ def run_test_reproducibility(rank, world_size, reference_rank, tempfile_name):
_ = optimizer.step(closure=closure)
# Update the optimizer state on the reference rank
optimizer.consolidate_state_dict(recipient_rank=reference_rank)
# Fetch the state on the reference rank, broadcast to the other ones
if rank == reference_rank:
optimizer_state_dict = optimizer.state_dict()
else:
optimizer_state_dict = {}
# Get a snapshot of the state at this point
optimizer_state_dict = copy.deepcopy(optimizer.state_dict(all_ranks=True))
model_state_dict = copy.deepcopy(model.state_dict())
# Run two steps, log the loss
_ = optimizer.step(closure=closure)
@@ -507,18 +512,20 @@ def run_test_reproducibility(rank, world_size, reference_rank, tempfile_name):
# Load the optimizer state dict, rewind the state two steps back
optimizer.load_state_dict(optimizer_state_dict)
model.load_state_dict(model_state_dict)
# Run two new steps, log the loss again and check that we get the same
_ = optimizer.step(closure=closure)
test_loss = optimizer.step(closure=closure)
assert torch.allclose(reference_loss, test_loss)
assert torch.allclose(reference_loss, test_loss), f"{reference_loss} vs {test_loss}. Reproducibility is broken"
dist.destroy_process_group()
# TODO(blefaudeux) Fix for torch v1.8.0
@pytest.mark.skipif(torch.__version__.split("+")[0].split(".") == ["1", "8", "0"], reason="disabled for torch 1.8.0")
@skip_if_single_gpu
def test_reproducibility():
world_size = 2
temp_file_name = tempfile.mkstemp()[1]
@@ -530,7 +537,7 @@ def test_reproducibility():
reference_rank = 0
mp.spawn(
run_test_collect_shards, args=(world_size, reference_rank, temp_file_name), nprocs=world_size, join=True,
run_test_reproducibility, args=(world_size, reference_rank, temp_file_name), nprocs=world_size, join=True,
)
......