Commit 33082d2b authored by Kexin Yu's avatar Kexin Yu
Browse files

fix attribute name mismatch in state_dict() and load_state_dict()

parent 858d7899
...@@ -196,7 +196,7 @@ class FP16_Optimizer(object): ...@@ -196,7 +196,7 @@ class FP16_Optimizer(object):
state_dict['scale_factor'] = self.scale_factor state_dict['scale_factor'] = self.scale_factor
state_dict['scale_window'] = self.scale_window state_dict['scale_window'] = self.scale_window
state_dict['optimizer_state_dict'] = self.optimizer.state_dict() state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
state_dict['fp32_groups_flat'] = self.fp32_groups_flat state_dict['fp32_groups'] = self.fp32_groups
return state_dict return state_dict
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
...@@ -238,5 +238,5 @@ class FP16_Optimizer(object): ...@@ -238,5 +238,5 @@ class FP16_Optimizer(object):
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params # constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params. # are guaranteed to exist, so we can just copy_() from the saved master params.
for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']): for current, saved in zip(self.fp32_groups, state_dict['fp32_groups']):
current.data.copy_(saved.data) current.data.copy_(saved.data)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment