Unverified Commit b666d6a4 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

Revert "[fix] oss dict load (#383)" (#384)

This reverts commit 8be9d930.
parent 8be9d930
......@@ -391,16 +391,16 @@ class OSS(Optimizer):
# NOTE: PyTorch 1.5 does not index linearly but with the id(params) at saving time
# we work around that here by using the fact that the params are ordered as in the param_groups
pytorch15_index_redirect = {k: i for i, k in enumerate(state_dict["state"].keys())}
for key, value in state_dict["state"].items():
param = self.index_to_param[pytorch15_index_redirect[key]]
for i_param, (key, value) in enumerate(state_dict["state"].items()):
param = self.index_to_param[i_param]
# Populate the sharded optimizer state on the fly
if self.param_to_rank[param] != self.rank:
state_dict["state"][key] = None
else:
if key in self.index_to_param:
param = self.index_to_param[i_param]
# Only add this state to the sharded optimizer if it owns this param
for pg in self.optim.param_groups:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment