Commit 9eedf896 authored by mohammad's avatar mohammad
Browse files

some small fixes

parent 983cc311
...@@ -217,7 +217,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer): ...@@ -217,7 +217,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
model_grads = [] model_grads = []
main_grads = [] main_grads = []
for model_group, main_group in zip(self.fp16_groups, for model_group, main_group in zip(self.fp16_groups,
self.fp32_from_fp16_groups): self.fp32_from_fp16_groups):
for model_param, main_param in zip(model_group, main_group): for model_param, main_param in zip(model_group, main_group):
if model_param.grad is not None: if model_param.grad is not None:
if main_param.grad is None: if main_param.grad is None:
...@@ -259,7 +259,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer): ...@@ -259,7 +259,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
model_data = [] model_data = []
main_data = [] main_data = []
for model_group, main_group in zip(self.fp16_groups, for model_group, main_group in zip(self.fp16_groups,
self.fp32_from_fp16_groups): self.fp32_from_fp16_groups):
for model_param, main_param in zip(model_group, main_group): for model_param, main_param in zip(model_group, main_group):
model_data.append(model_param.data) model_data.append(model_param.data)
main_data.append(main_param.data) main_data.append(main_param.data)
...@@ -282,7 +282,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer): ...@@ -282,7 +282,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
def reload_model_params(self):
    """Refresh the fp32 main parameters from the model's fp16 parameters.

    Thin delegator: all of the copy logic lives in
    ``_copy_model_params_to_main_params``; this public entry point exists
    so callers can resynchronize after the model weights were changed
    externally (e.g. after loading a checkpoint).
    """
    self._copy_model_params_to_main_params()
@torch.no_grad() @torch.no_grad()
def step(self): def step(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment