Unverified commit 02405740, authored by Benjamin Lefaudeux, committed by GitHub

[fix] oss and interleaved param groups (#483)

parent 64bbb6e1
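For context (an illustration added here, not part of the diff): the fix targets the case where the optimizer's param groups do not follow the module's natural parameter order. With a hypothetical trunk/head pair grouped the way the updated test does, chaining the groups yields a different parameter order than `module.parameters()`, which is exactly what index-based state-dict bookkeeping has to survive:

from itertools import chain

import torch

# Hypothetical toy modules standing in for the trunk/head used in the test below
trunk = torch.nn.Linear(4, 4)
head = torch.nn.Linear(4, 2)

# Interleaved groups: all trunk params but the last in one group, the last trunk param in the other
param_groups = [
    {"params": list(trunk.parameters())[:-1] + list(head.parameters()), "lr": 1e-5},
    {"params": [list(trunk.parameters())[-1]], "lr": 1e-4},
]

# The order implied by chaining the groups differs from the module's own parameter order
grouped_order = [id(p) for p in chain.from_iterable(g["params"] for g in param_groups)]
module_order = [id(p) for p in torch.nn.Sequential(trunk, head).parameters()]
assert grouped_order != module_order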
@@ -85,7 +85,6 @@ class OSS(Optimizer):
        self._per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict()  # device, rank, params
        self._param_rank: Dict[torch.Tensor, int] = {}
        self._partition_parameters: List[List[dict]] = []
-        self._index_to_param: Dict[int, torch.Tensor] = {}
        self._param_to_index: Dict[int, int] = {}
        self._local_params: Optional[List[torch.Tensor]] = None
@@ -160,15 +159,6 @@ class OSS(Optimizer):
        # Make sure that the iterator is not consumed, only expose a copy
        return self._local_params

-    @property
-    def index_to_param(self) -> Dict[int, torch.Tensor]:
-        """ Hash table in between parameter indices in the global optimizer scheme, and the actual params
-        """
-        if len(self._index_to_param) == 0:
-            self._index_to_param = {i: p for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))}
-
-        return self._index_to_param
-
    @property
    def param_to_index(self) -> Dict[int, int]:
        """ Hash table in between parameter indices in the global optimizer scheme, and the actual params
@@ -376,7 +366,7 @@ class OSS(Optimizer):
                    global_id = self.param_to_index[local_index_to_param_id[local_param_index]]
                    state_dict["state"][global_id] = s["state"][local_param_index]

-        # Make sure that the parameters are sorted in the state, as expected
+        # Make sure that the parameters are sorted in the state, as expected for a pytorch dict
        state_dict["state"] = dict(sorted(state_dict["state"].items()))

        return state_dict
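A side note on the sorting line above (illustration, not part of the diff): `dict(sorted(...))` relies on plain dicts preserving insertion order, so re-inserting the integer keys in ascending order reproduces the layout a vanilla `torch.optim.Optimizer.state_dict()` would have. A minimal sketch with made-up values:

# The consolidated state is filled shard by shard, so the integer keys can arrive out of order;
# re-inserting them sorted restores the usual 0..N-1 ordering of a pytorch optimizer state dict.
state = {2: {"step": 10}, 0: {"step": 10}, 1: {"step": 10}}
state = dict(sorted(state.items()))
assert list(state.keys()) == [0, 1, 2]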
@@ -389,17 +379,34 @@ class OSS(Optimizer):
            from a call to :meth:`state_dict`
        """

-        # NOTE: PyTorch 1.5 does not index linearly but with the id(params) at saving time
-        # we work around that here by using the fact that the params are ordered as in the param_groups
-        pytorch15_index_redirect = {k: i for i, k in enumerate(state_dict["state"].keys())}
+        # Update the state, trusting the ordering in param_groups
+        # Apart from the removal of states not owned by this rank, the pytorch logic is kept
+        # (See torch.optim.optimizer)
+        id_map = {
+            old_id: p
+            for old_id, p in zip(
+                chain.from_iterable((g["params"] for g in state_dict["param_groups"])),
+                chain.from_iterable((g["params"] for g in self.param_groups)),
+            )
+        }
+
+        # FIXME: pytorch1.5 compatibility, to be removed when 1.5 support ends
+        _param_list = list(chain.from_iterable((g["params"] for g in self.param_groups)))

        for key, value in state_dict["state"].items():
-            param = self.index_to_param[pytorch15_index_redirect[key]]
-
-            # Populate the sharded optimizer state on the fly
-            if self.param_to_rank[param] != self.rank:
-                state_dict["state"][key] = None
+            if key in id_map:
+                param = id_map[key]
+
+                # Populate the sharded optimizer state on the fly,
+                # remove the params that this rank does not own
+                if self.param_to_rank[param] != self.rank:
+                    state_dict["state"][key] = None
+                else:
+                    self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
            else:
+                # Not a param, copied as-is (backward compatibility or exotic optimizers)
+                print(key, "not in idmap")
+                param = _param_list[key]
                self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)

        super().load_state_dict(state_dict)
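The `id_map` built above follows the same convention as `torch.optim.Optimizer.load_state_dict`: the integer ids stored in the checkpoint's `param_groups` are zipped, in order, against the optimizer's current parameters. A standalone sketch of that remapping on a plain (non-sharded) optimizer, using hypothetical toy parameters:

from itertools import chain

import torch

params = [torch.nn.Parameter(torch.zeros(3)) for _ in range(3)]
optim = torch.optim.SGD([{"params": params[:2]}, {"params": params[2:]}], lr=0.1)

saved = optim.state_dict()  # "param_groups"[i]["params"] holds integer ids

# Same zip-based remapping as in OSS.load_state_dict above
id_map = {
    old_id: p
    for old_id, p in zip(
        chain.from_iterable(g["params"] for g in saved["param_groups"]),
        chain.from_iterable(g["params"] for g in optim.param_groups),
    )
}
assert all(id_map[i] is params[i] for i in range(3))

Because the mapping goes through the saved `param_groups` rather than through a separate global enumeration, it stays consistent when the groups are interleaved, which is what the updated test below exercises.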
@@ -515,7 +522,6 @@ class OSS(Optimizer):
        self._partition_parameters.clear()
        self._per_device_params.clear()
        self._param_rank.clear()
-        self._index_to_param.clear()
        self._param_to_index.clear()
        self._local_params = None
...
@@ -22,7 +22,13 @@ import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import fairscale.optim as optim
-from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_py39_no_cuda, skip_if_single_gpu
+from fairscale.utils.testing import (
+    check_same_model_params,
+    skip_if_no_cuda,
+    skip_if_py39_no_cuda,
+    skip_if_single_gpu,
+    torch_version,
+)

BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")
@@ -811,9 +817,11 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
    # Define a model to be trained by OSS
    oss_module = torch.nn.Sequential(trunk, head)

+    # Make sure that the param groups are interleaved, to catch an ordering bug in the state dict
    oss_trainable_params = [
-        {"params": trunk.parameters(), "lr": 1e-5},
-        {"params": head.parameters(), "lr": 1e-4},
+        {"params": list(trunk.parameters())[:-1] + list(head.parameters()), "lr": 1e-5},
+        {"params": list(trunk.parameters())[-1], "lr": 1e-4},
    ]

    optimizer_settings: Dict[Any, Any] = {}
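One detail in the new groups, illustrated here rather than taken from the diff: the second group passes a single `Parameter` instead of a list. `torch.optim.Optimizer.add_param_group` accepts a lone tensor and wraps it into a one-element list, so the group is still well formed. A quick check with hypothetical modules:

import torch

# Hypothetical stand-ins for the test's trunk and head
trunk = torch.nn.Linear(8, 8)
head = torch.nn.Linear(8, 2)

optim = torch.optim.SGD(
    [
        {"params": list(trunk.parameters())[:-1] + list(head.parameters()), "lr": 1e-5},
        {"params": list(trunk.parameters())[-1], "lr": 1e-4},  # a bare Parameter, no list needed
    ]
)
assert len(optim.param_groups[1]["params"]) == 1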
@@ -836,8 +844,8 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
    ddp_module = torch.nn.Sequential(ddp_trunk, ddp_head)

    ddp_trainable_params = [
-        {"params": ddp_trunk.parameters(), "lr": 1e-5},
-        {"params": ddp_head.parameters(), "lr": 1e-4},
+        {"params": list(ddp_trunk.parameters())[:-1] + list(ddp_head.parameters()), "lr": 1e-5},
+        {"params": list(ddp_trunk.parameters())[-1], "lr": 1e-4},
    ]
    ddp_optimizer = optimizer(ddp_trainable_params, **optimizer_settings)  # type: ignore
    ddp_model = DDP(module=ddp_module, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)
@@ -880,25 +888,26 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
            next(oss_module.parameters()).requires_grad = not next(oss_module.parameters()).requires_grad
            # sharded_optimizer.refresh_trainable()

-            # Check that the checkpoints are compatible
-            # - get states
-            ddp_state_dict = ddp_optimizer.state_dict()
-            sharded_optimizer.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)
-            sharded_optim_state_dict = sharded_optimizer.state_dict() if rank == RECIPIENT_RANK else {}
-            sharded_optim_state_dict = sync_object_ranks(sharded_optim_state_dict, RECIPIENT_RANK, device)
-
-            # - cross load the states
-            # run one step and check that the models are still the same
-            ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)  # OSS will remove some states
-            ddp_optimizer.load_state_dict(sharded_optim_state_dict)  # mixup on purpose !
-            sharded_optimizer.load_state_dict(ddp_state_dict)
-            check_step()
-
-            # - self load, rewind, check no problem
-            # run one step and check that the models are still the same
-            ddp_optimizer.load_state_dict(ddp_state_dict_ref)
-            sharded_optimizer.load_state_dict(sharded_optim_state_dict)
-            check_step()
+            # Check that the checkpoints are compatible (post pytorch 1.5)
+            if torch_version()[1] > 5:
+                # - get states
+                ddp_state_dict = ddp_optimizer.state_dict()
+                sharded_optimizer.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)
+                sharded_optim_state_dict = sharded_optimizer.state_dict() if rank == RECIPIENT_RANK else {}
+                sharded_optim_state_dict = sync_object_ranks(sharded_optim_state_dict, RECIPIENT_RANK, device)
+
+                # - cross load the states
+                # run one step and check that the models are still the same
+                ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)  # OSS will remove some states
+                ddp_optimizer.load_state_dict(sharded_optim_state_dict)  # mixup on purpose !
+                sharded_optimizer.load_state_dict(ddp_state_dict)
+                check_step()
+
+                # - self load, rewind, check no problem
+                # run one step and check that the models are still the same
+                ddp_optimizer.load_state_dict(ddp_state_dict_ref)
+                sharded_optimizer.load_state_dict(sharded_optim_state_dict)
+                check_step()

    for opt in [torch.optim.Adam, torch.optim.SGD]:
        check_optimizer_equivalence(opt, change_train_graph=False)
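The `torch_version()[1] > 5` gate above skips the checkpoint cross-loading checks on PyTorch 1.5, whose state-dict indexing is covered by the FIXME fallback in `load_state_dict`. As a hedged sketch only: if the fairscale `torch_version` helper were not available, an equivalent guard could be built directly on `torch.__version__` (assumed here to start with "major.minor"):

import re

import torch

def torch_major_minor() -> tuple:
    # Parse strings such as "1.8.1+cu111" down to (major, minor); hypothetical helper
    match = re.match(r"(\d+)\.(\d+)", torch.__version__)
    assert match is not None, f"unexpected torch version string: {torch.__version__}"
    return int(match.group(1)), int(match.group(2))

# Mirrors the gate used in the test above: only run the checks after pytorch 1.5
if torch_major_minor() > (1, 5):
    pass  # checkpoint compatibility checks would go here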