Unverified Commit 2d2412e2 authored by Benjamin Lefaudeux, committed by GitHub

[feat] Make OSS state available on all ranks (#500)

* extend the current state_dict interface: make it possible to do everything in a single call, and to checkpoint on all ranks
parent 8dc2030b
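In practice, the flow this change enables looks like the sketch below. This is a minimal illustration, not part of the diff: `model`, the SGD hyper-parameters and the checkpoint path are placeholders. Every rank calls `state_dict(all_ranks=True)`, which consolidates the shards internally and returns the full, PyTorch-compatible state everywhere.

# Minimal usage sketch (assumes torch.distributed is initialized and `model` exists);
# not part of this diff, names and hyper-parameters are placeholders.
import torch
import torch.distributed as dist
from fairscale.optim import OSS

optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

# ... training steps on every rank ...

# Single collective call, must be entered by all ranks since it broadcasts the shards:
full_state = optimizer.state_dict(all_ranks=True)

# Every rank now holds the same global, PyTorch-compatible state; here only
# rank 0 writes it to disk (placeholder path):
if dist.get_rank() == 0:
    torch.save(full_state, "oss_checkpoint.pt")

# The same dict can be reloaded on any rank without an extra broadcast:
optimizer.load_state_dict(full_state)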
@@ -297,19 +297,55 @@ class OSS(Optimizer):
     def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
         """Update the consolidated state_dict list, one per rank.
 
+        Arguments:
+            recipient_rank (int): on which rank to materialize the full state dict.
+                -1 is a special value, which means that all ranks should have the state
+
         .. warning: This needs to be called on all replicas"""
 
         # Sync lr and other attributes in case its been updated
         OSS._sync_param_groups(self.param_groups, self.optim.param_groups)
 
-        if self.rank == recipient_rank:
-            # Pull the sharded state from all the other replicas
-            # Store all the states in order, rank by rank
-            logging.debug("Pulling the sharded optimizer state from all replicas")
-            self._all_states = self._collect_sharded_states()
-        else:
-            # Acknowledge broadcasts, and send this rank's shard when needed
-            self._broadcast_state_dict()
+        # Pull the sharded state from all the other replicas
+        # Store all the states in order, rank by rank
+        logging.debug("Pulling the sharded optimizer state from all replicas")
+
+        self._all_states = []
+        should_collect_state = self.rank == recipient_rank or recipient_rank == -1
+        should_send_state = (self.rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1
+
+        for rank in range(self.world_size):
+            if rank == self.rank:
+                if should_collect_state:
+                    logging.debug("Saving self state")
+                    self._all_states.append(
+                        recursive_copy_to_device(self.optim.state_dict(), non_blocking=True, device=torch.device("cpu"))
+                    )
+
+                # Sync with other replicas
+                state_to_share = (
+                    self.optim.state_dict()
+                    if should_send_state
+                    else torch.tensor([0], dtype=torch.uint8, device=self._default_device)
+                )
+                broadcast_object(
+                    state_to_share, src_rank=self.global_rank, group=self.group, dist_device=self._default_device,
+                )
+            else:
+                # Fetch the optim state from the other replicas
+                replica_state = broadcast_object(
+                    torch.tensor([0], dtype=torch.uint8, device=self._default_device),
+                    src_rank=self._local_to_global_rank[rank],
+                    group=self.group,
+                    dist_device=self._default_device,
+                )
+
+                if should_collect_state:
+                    self._all_states.append(
+                        recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
+                    )
+
+                logging.debug("State from rank %s received", rank)
 
     def local_state_dict(self) -> dict:
         """ .. deprecated:: 0.1.5
@@ -325,29 +361,38 @@ class OSS(Optimizer):
         """
         return self.optim.state_dict()
 
-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self, all_ranks: bool = False) -> Dict[str, Any]:
         """Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the
-        sharded properties are not exposed. It contains two entries:
+        sharded properties are not exposed.
 
-        * state - a dict holding current optimization state. Its content
-            differs between optimizer classes.
+        Arguments:
+            all_ranks (bool): materialize the state on all ranks. In that case, `.state_dict()` needs to be called on
+                all ranks
 
-        * param_groups - a dict containing all parameter groups
-
-        .. warning:
-            If the state has not been consolidated, this returns a shard's worth, not the global state.
+        Returns:
+            a dict with two entries
+                * state - a dict holding current optimization state. Its content
+                    differs between optimizer classes.
+
+                * param_groups - a dict containing all parameter groups
 
         .. warning:
-            Returning the global state is limited to the replica which was responsible for the consolidation.
-            The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
+            Returning the global state is limited to the replica which was responsible for the consolidation,
+            if `all_ranks` was not set to `True`. In that case, the state may also not be up to date,
+            depending on when `consolidate_state_dict` was last called.
         """
 
-        if len(self._all_states) == 0:
+        if not all_ranks and len(self._all_states) == 0:
             raise RuntimeError(
                 "Optimizer state has not been consolidated on this rank. \
                 Please call `consolidate_state_dict()` on all ranks beforehand if you meant to save the global state"
             )
 
+        if all_ranks:
+            # Consolidate the state on every rank
+            self.consolidate_state_dict(recipient_rank=-1)
+
         # Unify the shard states and the state that pytorch would expect, given the model.
         # Indexation needs several redirections, since each shard only knows a limited scope of the model
         # - get the pytorch compliant parameter indexing
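Since the returned dict keeps the standard `state` / `param_groups` layout described in the docstring above, it can be inspected like any regular torch.optim state dict. A small illustrative check (names are placeholders, not part of the diff):

# Illustrative only: the consolidated state follows the usual PyTorch layout.
full_state = optimizer.state_dict(all_ranks=True)

assert "state" in full_state and "param_groups" in full_state
for group in full_state["param_groups"]:
    print(group["lr"])            # per-group hyper-parameters, as with torch.optim
print(len(full_state["state"]))   # one entry per parameter that carries state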
@@ -406,7 +451,6 @@ class OSS(Optimizer):
                     self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
             else:
                 # Not a param, copied as-is (backward compatibility or exotic optimizers)
-                print(key, "not in idmap")
                 param = _param_list[key]
                 self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
@@ -430,69 +474,6 @@ class OSS(Optimizer):
 
         self._setup_flat_buffers()
 
-    def _broadcast_state_dict(self) -> None:
-        """Broadcast this rank's state shard, discard others"""
-        # Tensor cannot be really empty, even if its size is meaningless
-        dummy_sync_tensor = torch.tensor([1], device=self._default_device)
-
-        for rank in range(self.world_size):
-            if rank == self.rank:
-                # Send the state to the reference replica
-                logging.debug(
-                    "Sending the sharded optimizer state to the reference replica from rank %s", rank,
-                )
-                # legacy compatibility for old torch versions
-                broadcast_object(
-                    self.local_state_dict(),
-                    src_rank=self.global_rank,
-                    group=self.group,
-                    dist_device=self._default_device,
-                )
-            else:
-                # Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
-                broadcast_object(
-                    torch.tensor([dummy_sync_tensor], dtype=torch.uint8, device=self._default_device),
-                    src_rank=self._local_to_global_rank[rank],
-                    group=self.group,
-                    dist_device=self._default_device,
-                )
-
-    def _collect_sharded_states(self) -> List[Dict[str, Any]]:
-        """Collect all the state shards, in CPU memory."""
-        all_states = []
-
-        for rank in range(self.world_size):
-            if rank == self.rank:
-                logging.debug("Saving self state")
-                all_states.append(
-                    recursive_copy_to_device(self.optim.state_dict(), non_blocking=True, device=torch.device("cpu"))
-                )
-
-                # Sync with other replicas
-                broadcast_object(
-                    torch.tensor([0], dtype=torch.uint8, device=self._default_device),
-                    src_rank=self.global_rank,
-                    group=self.group,
-                    dist_device=self._default_device,
-                )
-            else:
-                # Fetch the optim state from the other replicas
-                replica_state = broadcast_object(
-                    torch.tensor([0], dtype=torch.uint8, device=self._default_device),
-                    src_rank=self._local_to_global_rank[rank],
-                    group=self.group,
-                    dist_device=self._default_device,
-                )
-
-                all_states.append(
-                    recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
-                )
-
-                logging.debug("State from rank %s received", rank)
-
-        return all_states
-
     def add_param_group(self, param_group: dict) -> None:
         """Add a param group to the :class:`Optimizer` s `param_groups`.
@@ -24,6 +24,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 import fairscale.optim as optim
 from fairscale.utils.testing import (
     check_same_model_params,
+    check_same_models_across_ranks,
     skip_if_no_cuda,
     skip_if_py39_no_cuda,
     skip_if_single_gpu,
@@ -448,6 +449,14 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
     for state in optimizer.state.values():
         for _, _ in state.items():
             pass
 
+    # Test the state dict materialization on all ranks
+    _ = optimizer.step(closure=closure)
+    optimizer_state_dict = optimizer.state_dict(all_ranks=True)  # one per rank
+    optimizer.load_state_dict(optimizer_state_dict)
+    _ = optimizer.step(closure=closure)
+    check_same_models_across_ranks(model, dist.group.WORLD, params_should_be_equal=True, check_broadcast_buffers=False)
+
     dist.destroy_process_group()
@@ -477,6 +486,7 @@ def run_test_reproducibility(rank, world_size, reference_rank, tempfile_name):
     model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width))
     model.to(device)
+    model = DDP(model, device_ids=[device])
 
     loss_fn = torch.nn.L1Loss()
     loss_fn.to(device)
@@ -492,14 +502,9 @@ def run_test_reproducibility(rank, world_size, reference_rank, tempfile_name):
     _ = optimizer.step(closure=closure)
 
-    # Update the optimizer state on the reference rank
-    optimizer.consolidate_state_dict(recipient_rank=reference_rank)
-
-    # Fetch the state on the reference rank, broadcast to the other ones
-    if rank == reference_rank:
-        optimizer_state_dict = optimizer.state_dict()
-    else:
-        optimizer_state_dict = {}
+    # Get a snapshot of the state at this point
+    optimizer_state_dict = copy.deepcopy(optimizer.state_dict(all_ranks=True))
+    model_state_dict = copy.deepcopy(model.state_dict())
 
     # Run two steps, log the loss
     _ = optimizer.step(closure=closure)
@@ -507,18 +512,20 @@ def run_test_reproducibility(rank, world_size, reference_rank, tempfile_name):
 
     # Load the optimizer state dict, rewind the state two steps back
     optimizer.load_state_dict(optimizer_state_dict)
+    model.load_state_dict(model_state_dict)
 
     # Run two new steps, log the loss again and check that we get the same
     _ = optimizer.step(closure=closure)
     test_loss = optimizer.step(closure=closure)
 
-    assert torch.allclose(reference_loss, test_loss)
+    assert torch.allclose(reference_loss, test_loss), f"{reference_loss} vs {test_loss}. Reproducibility is broken"
 
     dist.destroy_process_group()
 
 
 # TODO(blefaudeux) Fix for torch v1.8.0
 @pytest.mark.skipif(torch.__version__.split("+")[0].split(".") == ["1", "8", "0"], reason="disabled for torch 1.8.0")
+@skip_if_single_gpu
 def test_reproducibility():
     world_size = 2
     temp_file_name = tempfile.mkstemp()[1]
@@ -530,7 +537,7 @@ def test_reproducibility():
     reference_rank = 0
 
     mp.spawn(
-        run_test_collect_shards, args=(world_size, reference_rank, temp_file_name), nprocs=world_size, join=True,
+        run_test_reproducibility, args=(world_size, reference_rank, temp_file_name), nprocs=world_size, join=True,
     )