Unverified commit 54bd62d3, authored by Benjamin Lefaudeux, committed by GitHub

[fix] OSS dict load/save fix - better fix than 383 and unit test (#386)

* WIP, needs to be fixed !

* should be a fix, many thanks Weiyi Zheng

* slightly better unit test, sorting the states on the way out

* reproducing the issue from Weiyi in a unit test, and finally properly fixing

* fixing unit test on pytorch1.5 - original loss diff 26.404895782470703 - 26.404342651367188
parent b666d6a4
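
For context, the API exercised by this fix is the OSS state_dict round trip: gather the per-rank shards, export a consolidated dict, then load it back. Below is a minimal sketch of that flow, assuming fairscale's OSS wrapper and a throwaway single-process gloo group; the address, port, tiny model and hyperparameters are illustrative only and not taken from this commit.

import os
import torch
import torch.distributed as dist
from fairscale.optim.oss import OSS

# Illustrative single-process setup; the real tests spawn one process per rank
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(4, 4)
optimizer = OSS(model.parameters(), optim=torch.optim.Adam, lr=1e-3)

# Take one step so the wrapped optimizer actually holds some state
model(torch.randn(2, 4)).sum().backward()
optimizer.step()

# Gather the shards on the recipient rank, export, then load back
optimizer.consolidate_state_dict(recipient_rank=0)
state = optimizer.state_dict()
optimizer.load_state_dict(state)

dist.destroy_process_group()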
@@ -379,6 +379,8 @@ class OSS(Optimizer):
                 global_id = self.param_to_index[local_index_to_param_id[local_param_index]]
                 state_dict["state"][global_id] = s["state"][local_param_index]
 
+        # Make sure that the parameters are sorted in the state, as expected
+        state_dict["state"] = dict(sorted(state_dict["state"].items()))
         return state_dict
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
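
Why the added sort matters: the consolidated "state" dict is keyed by global parameter index, but shards from different ranks can be merged out of numeric order, and a Python dict keeps insertion order. Any consumer that assumes key order matches parameter order (as the PyTorch 1.5 workaround in the next hunk does) would then mis-assign states. A tiny standalone illustration with made-up values:

# Shards merged rank by rank can leave the global indices out of order
merged_state = {2: {"step": 3}, 0: {"step": 3}, 1: {"step": 3}}

# The one-liner added above restores positional ordering
merged_state = dict(sorted(merged_state.items()))
assert list(merged_state.keys()) == [0, 1, 2]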
@@ -391,23 +393,16 @@ class OSS(Optimizer):
 
         # NOTE: PyTorch 1.5 does not index linearly but with the id(params) at saving time
         # we work around that here by using the fact that the params are ordered as in the param_groups
+        pytorch15_index_redirect = {k: i for i, k in enumerate(state_dict["state"].keys())}
 
-        for i_param, (key, value) in enumerate(state_dict["state"].items()):
-            param = self.index_to_param[i_param]
+        for key, value in state_dict["state"].items():
+            param = self.index_to_param[pytorch15_index_redirect[key]]
 
             # Populate the sharded optimizer state on the fly
             if self.param_to_rank[param] != self.rank:
                 state_dict["state"][key] = None
-
-            if key in self.index_to_param:
-                param = self.index_to_param[i_param]
-
-                # Only add this state to the sharded optimizer if it owns this param
-                for pg in self.optim.param_groups:
-                    if id(param) in [id(p) for p in pg["params"]]:
-                        self.optim.state[param] = recursive_copy_to_device(
-                            value, non_blocking=True, device=param.device
-                        )
+            else:
+                self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
 
         super().load_state_dict(state_dict)
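
The redirect added to load_state_dict addresses the PyTorch 1.5 quirk called out in the NOTE above: the keys that 1.5 writes into "state" are not the linear indices 0..N-1, but they still follow the param_groups ordering, so each key's position recovers the right parameter. A standalone sketch of the same trick, with hypothetical keys and a hypothetical index_to_param mapping:

# Keys as PyTorch 1.5 might save them (id-like, non-contiguous), still in parameter order
saved_state = {140230011: {"step": 3}, 140234567: {"step": 3}}
index_to_param = {0: "weight", 1: "bias"}  # hypothetical position -> parameter mapping

# Same idea as pytorch15_index_redirect: map each key to its position
index_redirect = {key: i for i, key in enumerate(saved_state.keys())}
for key, value in saved_state.items():
    param = index_to_param[index_redirect[key]]
    print(param, "->", value)  # weight -> {'step': 3}, then bias -> {'step': 3}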
@@ -774,7 +774,9 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
     def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer], change_train_graph: bool = False):
         # Any model works. Add one different buffer per rank
-        trunk = torch.nn.Sequential(torch.nn.Linear(in_channels, hidden), torch.nn.Linear(hidden, hidden))
+        trunk = torch.nn.Sequential(
+            torch.nn.Linear(in_channels, hidden), torch.nn.Linear(hidden, hidden), torch.nn.Linear(hidden, hidden)
+        )
         trunk.register_buffer("test_buffer", torch.ones((1)) * rank)
         trunk.to(device)
@@ -832,8 +834,8 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
             loss_sharded_optim = cast(torch.Tensor, sharded_optimizer.step(closure=closure_sharded))
             assert torch.allclose(
-                loss_ddp, loss_sharded_optim
-            ), f"Losses differ in between Pytorch optim and OSS\nworld size {world_size}"
+                loss_ddp, loss_sharded_optim, rtol=1e-3
+            ), f"Losses differ in between Pytorch optim and OSS\n {loss_ddp.item()} - {loss_sharded_optim.item()} - world size {world_size}"
 
             check_same_model_params(oss_ddp_model, ddp_model)
@@ -859,10 +861,16 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
         sharded_optim_state_dict = sync_object_ranks(sharded_optim_state_dict, RECIPIENT_RANK, device)
 
         # - cross load the states
+        # run one step and check that the models are still the same
+        ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)  # OSS will remove some states
         ddp_optimizer.load_state_dict(sharded_optim_state_dict)  # mixup on purpose !
         sharded_optimizer.load_state_dict(ddp_state_dict)
+        check_step()
 
-        # - run one step and check that the models are still the same
+        # - self load, rewind, check no problem
+        # run one step and check that the models are still the same
+        ddp_optimizer.load_state_dict(ddp_state_dict_ref)
+        sharded_optimizer.load_state_dict(sharded_optim_state_dict)
         check_step()
 
     for opt in [torch.optim.Adam, torch.optim.SGD]:
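
One detail worth noting in the updated test: ddp_state_dict is deep-copied before the cross-load because, per the oss.py hunk above, OSS.load_state_dict writes None into the entries of the dict it receives for parameters the current rank does not shard, so the original dict is needed again for the rewind step. A small stand-in showing that mutation (the dict contents are made up):

import copy

incoming = {"state": {0: {"step": 1}, 1: {"step": 1}}, "param_groups": []}
reference = copy.deepcopy(incoming)  # pristine copy kept for the rewind step

owned = {0}  # pretend this rank only shards parameter 0
for key in incoming["state"]:
    if key not in owned:
        incoming["state"][key] = None  # mirrors what OSS does to non-owned entries

assert incoming["state"][1] is None
assert reference["state"][1] is not None  # the deep copy is untouched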