Unverified Commit b666d6a4 authored by Benjamin Lefaudeux, committed by GitHub

Revert "[fix] oss dict load (#383)" (#384)

This reverts commit 8be9d930.
parent 8be9d930
...
@@ -391,16 +391,16 @@ class OSS(Optimizer):
             # NOTE: PyTorch 1.5 does not index linearly but with the id(params) at saving time
             # we work around that here by using the fact that the params are ordered as in the param_groups
-            pytorch15_index_redirect = {k: i for i, k in enumerate(state_dict["state"].keys())}
-            for key, value in state_dict["state"].items():
-                param = self.index_to_param[pytorch15_index_redirect[key]]
+            for i_param, (key, value) in enumerate(state_dict["state"].items()):
+                param = self.index_to_param[i_param]
 
                 # Populate the sharded optimizer state on the fly
                 if self.param_to_rank[param] != self.rank:
                     state_dict["state"][key] = None
-                else:
+                if key in self.index_to_param:
+                    param = self.index_to_param[i_param]
+
                     # Only add this state to the sharded optimizer if it owns this param
                     for pg in self.optim.param_groups:
...
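For context, the change being reverted here replaced a positional enumeration of state_dict["state"] with an explicit key-to-index redirect, because PyTorch 1.5 could key the saved optimizer state by id(param) rather than by linear index. Below is a minimal, self-contained sketch of that redirect trick; the parameter names and the fake state dictionary are invented for illustration, and none of this is the actual fairscale code path.

# Minimal sketch of the pytorch15_index_redirect workaround
# (illustrative data only, not the fairscale OSS implementation).

# Parameters in the same order as they appear in the param_groups.
params = ["weight", "bias"]                  # hypothetical parameter handles
index_to_param = dict(enumerate(params))     # {0: "weight", 1: "bias"}

# PyTorch 1.5 could key the saved optimizer state by id(param) at saving time,
# instead of by the linear indices 0..N-1 that later versions use.
state_dict = {"state": {140230011: {"step": 3}, 140230777: {"step": 3}}}

# The workaround: the keys are opaque, but they are ordered like the param_groups,
# so a linear index can be re-derived from the enumeration order of the keys.
pytorch15_index_redirect = {k: i for i, k in enumerate(state_dict["state"].keys())}

for key, value in state_dict["state"].items():
    param = index_to_param[pytorch15_index_redirect[key]]
    print(f"state key {key} -> linear index {pytorch15_index_redirect[key]} -> {param}")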