Commit ae921de2 authored by Michael Carilli

Fixing FP16_Optimizer handling of LBFGS

parent d695b68b
@@ -121,9 +121,8 @@ class FP16_Optimizer(object):
             print("FP16_Optimizer processing param group {}:".format(i))
             fp16_params_this_group = []
             fp32_params_this_group = []
-            master_params_this_group = []
             fp32_from_fp16_params_this_group = []
-            for param in param_group['params']:
+            for i, param in enumerate(param_group['params']):
                 if param.requires_grad:
                     if param.type() == 'torch.cuda.HalfTensor':
                         print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
@@ -131,7 +130,7 @@ class FP16_Optimizer(object):
                         fp16_params_this_group.append(param)
                         master_param = param.detach().clone().float()
                         master_param.requires_grad = True
-                        master_params_this_group.append(master_param)
+                        param_group['params'][i] = master_param
                         fp32_from_fp16_params_this_group.append(master_param)
                         # Reset existing state dict key to the new master param.
                         # We still need to recast per-param state tensors, if any, to FP32.
@@ -141,14 +140,12 @@ class FP16_Optimizer(object):
                         print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
                               .format(param.size()))
                         fp32_params_this_group.append(param)
-                        master_params_this_group.append(param)
+                        param_group['params'][i] = param
                     else:
                         raise TypeError("Wrapped parameters must be either "
                                         "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                         "Received {}".format(param.type()))
-            param_group['params'] = master_params_this_group
         self.fp16_groups.append(fp16_params_this_group)
         self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
         self.fp32_from_fp32_groups.append(fp32_params_this_group)
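The substance of the fix: FP32 master parameters are now swapped into the existing param_group['params'] list element by element, instead of rebinding the key to a freshly built master_params_this_group list. This matters for LBFGS because torch.optim.LBFGS caches a reference to that list at construction (self._params = self.param_groups[0]['params']), so rebinding the key would leave the cached list still pointing at the original FP16 params. A minimal sketch of the aliasing behavior (illustrative only, not part of the commit; fp16_param and master are made-up names, and opt._params is a private LBFGS attribute):

    import torch

    # Toy parameter; CPU is fine for demonstrating the list aliasing.
    fp16_param = torch.zeros(4, dtype=torch.half, requires_grad=True)
    opt = torch.optim.LBFGS([fp16_param], lr=0.1)

    # LBFGS cached self._params = self.param_groups[0]['params'] in __init__.
    group = opt.param_groups[0]

    master = fp16_param.detach().clone().float()
    master.requires_grad = True

    # In-place element assignment mutates the very list object LBFGS holds:
    group['params'][0] = master
    assert opt._params[0] is master

    # By contrast, group['params'] = [master] would rebind the key to a new
    # list, and opt._params would still reference the old FP16 param.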
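For context, a rough usage sketch of the wrapper around LBFGS. This is an assumption-laden sketch, not part of the commit: it assumes the apex.fp16_utils import path, that FP16_Optimizer.step forwards a closure to the wrapped optimizer, and that backward performs the scaled backward pass; model, inputs, and the loss are placeholders, and a CUDA device is required:

    import torch
    from apex.fp16_utils import FP16_Optimizer  # import path is an assumption

    model = torch.nn.Linear(512, 512).cuda().half()
    inputs = torch.randn(8, 512, device='cuda', dtype=torch.half)

    optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
    optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)

    def closure():
        optimizer.zero_grad()
        loss = model(inputs).float().pow(2).mean()  # placeholder loss
        # backward() scales the loss, runs the FP16 backward pass, and
        # copies gradients to the FP32 master params.
        optimizer.backward(loss)
        return loss

    # LBFGS may evaluate the closure several times per step.
    optimizer.step(closure)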