"doc/git@developer.sourcefind.cn:ox696c/ktransformers.git" did not exist on "2d3aaef8b63a08d3e17cdd30d1cbb851e2898f9c"
Unverified Commit 8be9d930 authored by Benjamin Lefaudeux, committed by GitHub

[fix] oss dict load (#383)

* many thanks Weiyi Zheng
parent 13445c55
@@ -391,16 +391,16 @@ class OSS(Optimizer):
         # NOTE: PyTorch 1.5 does not index linearly but with the id(params) at saving time
         # we work around that here by using the fact that the params are ordered as in the param_groups
-        for i_param, (key, value) in enumerate(state_dict["state"].items()):
-            param = self.index_to_param[i_param]
+        pytorch15_index_redirect = {k: i for i, k in enumerate(state_dict["state"].keys())}
+
+        for key, value in state_dict["state"].items():
+            param = self.index_to_param[pytorch15_index_redirect[key]]
 
             # Populate the sharded optimizer state on the fly
             if self.param_to_rank[param] != self.rank:
                 state_dict["state"][key] = None
-
-            if key in self.index_to_param:
-                param = self.index_to_param[i_param]
-
+            else:
                 # Only add this state to the sharded optimizer if it owns this param
                 for pg in self.optim.param_groups:
...
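For context, here is a minimal standalone sketch of the indexing workaround this patch introduces. The values and the index_to_param list below are hypothetical stand-ins, not fairscale objects; the point is only the redirect dict. PyTorch 1.5 can key state_dict["state"] by id(param) rather than by a linear index, but the keys still come out ordered as the params appear in param_groups, so a positional index can be rebuilt from key order:

# Sketch of the pytorch15_index_redirect trick (hypothetical values, not fairscale code).
# PyTorch 1.5 may use id(param)-style keys, but key order matches param_groups order.
state = {140234: {"step": 10}, 140567: {"step": 10}}  # id-style keys (hypothetical)

# Rebuild positional indices from key order, exactly as the patch does:
pytorch15_index_redirect = {k: i for i, k in enumerate(state.keys())}

index_to_param = ["param_a", "param_b"]  # hypothetical stand-in for OSS.index_to_param
for key, value in state.items():
    param = index_to_param[pytorch15_index_redirect[key]]
    print(param, value)  # param_a {'step': 10}, then param_b {'step': 10}

This makes the lookup robust to the key type: whether the saved state uses linear indices (newer PyTorch) or id-based keys (PyTorch 1.5), the enumeration order recovers the correct parameter.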