Unverified Commit 494f8ab3 authored by Kexin Yu's avatar Kexin Yu Committed by GitHub
Browse files

Fix attribute name mismatch in state_dict() and load_state_dict() (#704)

* updated apex.contrib.optimizers.FP16_Optimizer and FusedSGD

* fix attribute name mismatch in state_dict() and load_state_dict()
parent 2ca894da
......@@ -196,7 +196,7 @@ class FP16_Optimizer(object):
state_dict['scale_factor'] = self.scale_factor
state_dict['scale_window'] = self.scale_window
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
state_dict['fp32_groups_flat'] = self.fp32_groups_flat
state_dict['fp32_groups'] = self.fp32_groups
return state_dict
def load_state_dict(self, state_dict):
......@@ -238,5 +238,5 @@ class FP16_Optimizer(object):
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params.
for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']):
for current, saved in zip(self.fp32_groups, state_dict['fp32_groups']):
current.data.copy_(saved.data)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment