Unverified Commit 51294d90 authored by Matthew Douglas's avatar Matthew Douglas Committed by GitHub
Browse files

Fix optimizer support for Python <= 3.9 (#1379)

parent 776140a5
...@@ -173,7 +173,7 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -173,7 +173,7 @@ class Optimizer8bit(torch.optim.Optimizer):
raise ValueError("loaded state dict has a different number of parameter groups") raise ValueError("loaded state dict has a different number of parameter groups")
param_lens = (len(g["params"]) for g in groups) param_lens = (len(g["params"]) for g in groups)
saved_lens = (len(g["params"]) for g in saved_groups) saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens, strict=True)): if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError( raise ValueError(
"loaded state dict contains a parameter group that doesn't match the size of optimizer's group", "loaded state dict contains a parameter group that doesn't match the size of optimizer's group",
) )
...@@ -184,7 +184,6 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -184,7 +184,6 @@ class Optimizer8bit(torch.optim.Optimizer):
for old_id, p in zip( for old_id, p in zip(
chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in saved_groups),
chain.from_iterable(g["params"] for g in groups), chain.from_iterable(g["params"] for g in groups),
strict=True,
) )
} }
...@@ -226,7 +225,7 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -226,7 +225,7 @@ class Optimizer8bit(torch.optim.Optimizer):
new_group["params"] = group["params"] new_group["params"] = group["params"]
return new_group return new_group
param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups, strict=True)] param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({"state": state, "param_groups": param_groups}) self.__setstate__({"state": state, "param_groups": param_groups})
def to_gpu(self): def to_gpu(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment