Unverified Commit b0e6b9bd authored by Benjamin Lefaudeux, committed by GitHub

[chore] OSS to 100% coverage (#618)

parent d9f36130
@@ -301,20 +301,6 @@ class OSS(Optimizer):
                 logging.debug("State from rank %s received", rank)
 
-    def local_state_dict(self) -> dict:
-        """ .. deprecated:: 0.1.5
-
-        Returns this rank's state_dict as a :class:`dict` which contains two entries:
-
-        * state - a dict holding current optimization state. Its content
-          differs between optimizer classes.
-
-        * param_groups - a dict containing all parameter groups
-
-        .. warning: This does not represent the optimizer state dict, only a shard.
-        """
-        return self.optim.state_dict()
-
     def state_dict(self, all_ranks: bool = False) -> Dict[str, Any]:
         """Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the
         sharded properties are not exposed.
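The hunk above drops the deprecated per-shard `local_state_dict()`; checkpointing goes through the consolidated, PyTorch-compatible `state_dict()` instead. A minimal, hypothetical sketch of that save flow, assuming the fairscale OSS API of this era and an already-initialized `torch.distributed` process group (the model, learning rate, and file name are placeholders):

```python
import torch
import torch.distributed as dist
from fairscale.optim.oss import OSS

# Hypothetical setup: torch.distributed is assumed to be initialized already
# (e.g. via dist.init_process_group), one process per rank.
model = torch.nn.Linear(8, 8)
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

# ... forward/backward/optimizer.step() iterations would happen here ...

# Each rank ships its shard to the recipient rank; rank 0 then exposes the
# full, PyTorch-compatible state through state_dict() and can save it.
optimizer.consolidate_state_dict(recipient_rank=0)
if dist.get_rank() == 0:
    torch.save(optimizer.state_dict(), "oss_checkpoint.pt")
```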
@@ -390,11 +376,7 @@ class OSS(Optimizer):
                 )
             }
 
-        # FIXME: pytorch1.5 compatibility, to be removed when 1.5 support ends
-        _param_list = list(chain.from_iterable((g["params"] for g in self.param_groups)))
-
         for key, value in state_dict["state"].items():
-            if key in id_map:
             param = id_map[key]
 
             # Populate the sharded optimizer state on the fly,
@@ -403,10 +385,6 @@ class OSS(Optimizer):
                 state_dict["state"][key] = {}
             else:
                 self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
-            else:
-                # Not a param, copied as-is (backward compatibility or exotic optimizers)
-                param = _param_list[key]
-                self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
 
         super().load_state_dict(state_dict)
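The `load_state_dict()` path kept above takes the full consolidated state dict on every rank and rebuilds each shard on the fly (parameters owned by other ranks get an empty state entry). Continuing the hypothetical sketch from the save example:

```python
# Every rank reads the consolidated checkpoint; OSS.load_state_dict() keeps
# only the optimizer state for the parameters this rank actually owns.
state = torch.load("oss_checkpoint.pt", map_location="cpu")
optimizer.load_state_dict(state)
```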