Unverified commit 02405740 authored by Benjamin Lefaudeux, committed by GitHub

[fix] oss and interleaved param groups (#483)

parent 64bbb6e1
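
For context, "interleaved" parameter groups are groups whose members do not follow the order in which the parameters appear in the model, e.g. a group that mixes trunk and head parameters. A minimal sketch of handing such groups to OSS, assuming a single-process gloo group just so the optimizer can be constructed (module names, shapes, and the rendezvous address are illustrative, not taken from this commit):

```python
import torch
import torch.distributed as dist
import fairscale.optim as optim

# A trivial single-rank process group; OSS shards its state across the ranks of this group.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

trunk, head = torch.nn.Linear(8, 8), torch.nn.Linear(8, 2)

# "Interleaved" groups: the last trunk parameter sits in its own group, so group
# membership no longer matches the model's natural parameter ordering.
param_groups = [
    {"params": list(trunk.parameters())[:-1] + list(head.parameters()), "lr": 1e-5},
    {"params": list(trunk.parameters())[-1:], "lr": 1e-4},
]

sharded_optimizer = optim.OSS(params=param_groups, optim=torch.optim.SGD, lr=1e-5, momentum=0.9)
```

With groups like these, the previous index-based bookkeeping could pair optimizer state with the wrong parameter when a checkpoint was reloaded; the changes below match state to parameters by their position in param_groups instead.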
@@ -85,7 +85,6 @@ class OSS(Optimizer):
self._per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict() # device, rank, params
self._param_rank: Dict[torch.Tensor, int] = {}
self._partition_parameters: List[List[dict]] = []
self._index_to_param: Dict[int, torch.Tensor] = {}
self._param_to_index: Dict[int, int] = {}
self._local_params: Optional[List[torch.Tensor]] = None
@@ -160,15 +159,6 @@ class OSS(Optimizer):
# Make sure that the iterator is not consumed, only expose a copy
return self._local_params
@property
def index_to_param(self) -> Dict[int, torch.Tensor]:
""" Hash table in between parameter indices in the global optimizer scheme, and the actual params
"""
if len(self._index_to_param) == 0:
self._index_to_param = {i: p for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))}
return self._index_to_param
@property
def param_to_index(self) -> Dict[int, int]:
""" Hash table in between parameter indices in the global optimizer scheme, and the actual params
@@ -376,7 +366,7 @@ class OSS(Optimizer):
global_id = self.param_to_index[local_index_to_param_id[local_param_index]]
state_dict["state"][global_id] = s["state"][local_param_index]
# Make sure that the parameters are sorted in the state, as expected
# Make sure that the parameters are sorted in the state, as expected for a pytorch dict
state_dict["state"] = dict(sorted(state_dict["state"].items()))
return state_dict
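
As a standalone illustration of the sorting step above (nothing here is specific to OSS): Python dicts preserve insertion order, so rebuilding the consolidated state from sorted items yields the index-ordered layout a regular torch.optim state dict would have.

```python
# State entries are gathered shard by shard, so global parameter indices can
# arrive out of order; sorting the items restores index order.
state = {2: {"step": 10}, 0: {"step": 10}, 1: {"step": 10}}
state = dict(sorted(state.items()))
assert list(state.keys()) == [0, 1, 2]
```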
@@ -389,18 +379,35 @@ class OSS(Optimizer):
from a call to :meth:`state_dict`
"""
# NOTE: PyTorch 1.5 does not index linearly but with the id(params) at saving time
# we work around that here by using the fact that the params are ordered as in the param_groups
pytorch15_index_redirect = {k: i for i, k in enumerate(state_dict["state"].keys())}
# Update the state, trusting the ordering in param_groups
# Apart from the removal of states not owned by this rank, the pytorch logic is kept
# (See torch.optim.optimizer)
id_map = {
old_id: p
for old_id, p in zip(
chain.from_iterable((g["params"] for g in state_dict["param_groups"])),
chain.from_iterable((g["params"] for g in self.param_groups)),
)
}
# FIXME: pytorch1.5 compatibility, to be removed when 1.5 support ends
_param_list = list(chain.from_iterable((g["params"] for g in self.param_groups)))
for key, value in state_dict["state"].items():
param = self.index_to_param[pytorch15_index_redirect[key]]
if key in id_map:
param = id_map[key]
# Populate the sharded optimizer state on the fly
# Populate the sharded optimizer state on the fly,
# remove the params that this rank does not own
if self.param_to_rank[param] != self.rank:
state_dict["state"][key] = None
else:
self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
else:
# Key not found in the id map (e.g. pytorch 1.5 checkpoints): fall back to positional lookup in the flattened param list
print(key, "not in idmap")
param = _param_list[key]
self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
super().load_state_dict(state_dict)
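
The id_map built above mirrors the positional matching done by torch.optim.Optimizer.load_state_dict: the i-th parameter reference saved in param_groups is paired with the i-th live parameter, regardless of how the parameters are split across groups. A self-contained sketch of that idea (build_id_map is a hypothetical helper, not part of this commit):

```python
from itertools import chain
from typing import Any, Dict, List

import torch


def build_id_map(
    saved_param_groups: List[Dict[str, Any]], current_param_groups: List[Dict[str, Any]]
) -> Dict[Any, torch.Tensor]:
    # Saved groups carry parameter ids/indices, current groups carry live tensors;
    # zipping the two flattened sequences pairs them purely by position.
    saved_keys = chain.from_iterable(g["params"] for g in saved_param_groups)
    live_params = chain.from_iterable(g["params"] for g in current_param_groups)
    return {old_key: p for old_key, p in zip(saved_keys, live_params)}
```

With such a map, each saved state entry can be reattached to the right parameter even when the groups are interleaved, and entries whose parameter is owned by another rank can simply be dropped.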
@@ -515,7 +522,6 @@ class OSS(Optimizer):
self._partition_parameters.clear()
self._per_device_params.clear()
self._param_rank.clear()
self._index_to_param.clear()
self._param_to_index.clear()
self._local_params = None
@@ -22,7 +22,13 @@ import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import fairscale.optim as optim
from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_py39_no_cuda, skip_if_single_gpu
from fairscale.utils.testing import (
check_same_model_params,
skip_if_no_cuda,
skip_if_py39_no_cuda,
skip_if_single_gpu,
torch_version,
)
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore
DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")
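
For orientation, a per-rank test such as run_ddp_parity is typically driven through torch.multiprocessing; a minimal launch sketch under that assumption, reusing this file's run_ddp_parity and BACKEND (the real wrapper in the test file may differ):

```python
import tempfile

import torch.multiprocessing as mp

if __name__ == "__main__":
    world_size = 2
    temp_file_name = tempfile.mkstemp()[1]  # file-based rendezvous for init_process_group
    # mp.spawn prepends the rank to the argument list of the target function.
    mp.spawn(run_ddp_parity, args=(world_size, BACKEND, temp_file_name), nprocs=world_size, join=True)
```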
@@ -811,9 +817,11 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
# Define a model to be trained by OSS
oss_module = torch.nn.Sequential(trunk, head)
# Make sure that the param groups are interleaved, to catch an ordering bug in the state dict
oss_trainable_params = [
{"params": trunk.parameters(), "lr": 1e-5},
{"params": head.parameters(), "lr": 1e-4},
{"params": list(trunk.parameters())[:-1] + list(head.parameters()), "lr": 1e-5},
{"params": list(trunk.parameters())[-1], "lr": 1e-4},
]
optimizer_settings: Dict[Any, Any] = {}
@@ -836,8 +844,8 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
ddp_module = torch.nn.Sequential(ddp_trunk, ddp_head)
ddp_trainable_params = [
{"params": ddp_trunk.parameters(), "lr": 1e-5},
{"params": ddp_head.parameters(), "lr": 1e-4},
{"params": list(ddp_trunk.parameters())[:-1] + list(ddp_head.parameters()), "lr": 1e-5},
{"params": list(ddp_trunk.parameters())[-1], "lr": 1e-4},
]
ddp_optimizer = optimizer(ddp_trainable_params, **optimizer_settings) # type: ignore
ddp_model = DDP(module=ddp_module, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)
@@ -880,7 +888,8 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
next(oss_module.parameters()).requires_grad = not next(oss_module.parameters()).requires_grad
# sharded_optimizer.refresh_trainable()
# Check that the checkpoints are compatible
# Check that the checkpoints are compatible (post pytorch 1.5)
if torch_version()[1] > 5:
# - get states
ddp_state_dict = ddp_optimizer.state_dict()
sharded_optimizer.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)
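
A single-process sketch of the property this gated check asserts, using a plain torch optimizer (shapes and names are illustrative): with interleaved groups, a state dict saved after a step must reload onto the right parameters.

```python
import copy

import torch

trunk, head = torch.nn.Linear(4, 4), torch.nn.Linear(4, 2)


def make_groups():
    # Same interleaving pattern as the test above: the last trunk parameter gets its own group.
    return [
        {"params": list(trunk.parameters())[:-1] + list(head.parameters()), "lr": 1e-5},
        {"params": list(trunk.parameters())[-1:], "lr": 1e-4},
    ]


opt = torch.optim.SGD(make_groups(), momentum=0.9)
head(trunk(torch.randn(3, 4))).sum().backward()
opt.step()

# Round-trip the state through a fresh optimizer built from the same interleaved groups.
fresh = torch.optim.SGD(make_groups(), momentum=0.9)
fresh.load_state_dict(copy.deepcopy(opt.state_dict()))
assert len(fresh.state) == len(opt.state)  # momentum buffers landed on every parameter
```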