"docs/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "0fb07d0e694f9820620aad6714291851604a5812"
Commit ae921de2 authored by Michael Carilli

Fixing FP16_Optimizer handling of LBFGS

parent d695b68b
@@ -121,9 +121,8 @@ class FP16_Optimizer(object):
             print("FP16_Optimizer processing param group {}:".format(i))
             fp16_params_this_group = []
             fp32_params_this_group = []
-            master_params_this_group = []
             fp32_from_fp16_params_this_group = []
-            for param in param_group['params']:
+            for i, param in enumerate(param_group['params']):
                 if param.requires_grad:
                     if param.type() == 'torch.cuda.HalfTensor':
                         print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
@@ -131,7 +130,7 @@ class FP16_Optimizer(object):
                         fp16_params_this_group.append(param)
                         master_param = param.detach().clone().float()
                         master_param.requires_grad = True
-                        master_params_this_group.append(master_param)
+                        param_group['params'][i] = master_param
                         fp32_from_fp16_params_this_group.append(master_param)
                         # Reset existing state dict key to the new master param.
                         # We still need to recast per-param state tensors, if any, to FP32.
@@ -141,14 +140,12 @@ class FP16_Optimizer(object):
                         print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
                               .format(param.size()))
                         fp32_params_this_group.append(param)
-                        master_params_this_group.append(param)
+                        param_group['params'][i] = param
                     else:
                         raise TypeError("Wrapped parameters must be either "
                                         "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                         "Received {}".format(param.type()))
-            param_group['params'] = master_params_this_group
             self.fp16_groups.append(fp16_params_this_group)
             self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
             self.fp32_from_fp32_groups.append(fp32_params_this_group)
...
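
For context, the sketch below (written for this note, not part of the commit) illustrates why overwriting the entries of param_group['params'] in place, rather than rebinding the key to a freshly built list, matters for torch.optim.LBFGS: LBFGS captures its own reference to param_groups[0]['params'] (its internal _params attribute) at construction, so a rebound list would leave it stepping the stale FP16 parameters. The snippet assumes a CUDA device is available, and the model and variable names are illustrative only.

import torch

# Illustrative sketch only -- not the apex implementation. Assumes a CUDA device.
model = torch.nn.Linear(4, 4).cuda().half()               # FP16 parameters
optimizer = torch.optim.LBFGS(model.parameters(), lr=1.0)

# LBFGS keeps its own reference to this list (internally, self._params),
# captured once when the optimizer is constructed.
params_list = optimizer.param_groups[0]['params']

# Old behavior (roughly): build a new list of FP32 masters and rebind the key.
# LBFGS's cached reference would still point at the old FP16 tensors:
#   optimizer.param_groups[0]['params'] = [p.detach().clone().float() for p in params_list]

# New behavior: overwrite entries of the existing list in place, so every
# holder of that list object -- including LBFGS -- sees the FP32 masters.
for i, p in enumerate(params_list):
    if p.type() == 'torch.cuda.HalfTensor':
        master = p.detach().clone().float()
        master.requires_grad = True
        params_list[i] = master

# The cached reference and the param group now agree on the FP32 masters.
assert optimizer._params is optimizer.param_groups[0]['params']
assert all(p.dtype == torch.float32 for p in optimizer._params)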