"examples/vscode:/vscode.git/clone" did not exist on "9a0511c8e91a7f633c9c3292fccbcbad5281d1f5"
Commit 4d9dcb57 authored by Deyu Fu

address comments

parent be42aad5
 import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+import ctypes
+lib = ctypes.cdll.LoadLibrary(None)
+lib.THCudaHalfTensor_normall.argtypes=[ctypes.c_void_p, ctypes.c_void_p]
+lib.THCudaHalfTensor_normall.restype = ctypes.c_float
+def fused_norm(input):
+    if input.type() == 'torch.cuda.HalfTensor':
+        # 16384 is half 2 if you stare at it long enough
+        return lib.THCudaHalfTensor_normall(torch.cuda._state_cdata,
+                                            input._cdata, 16384)
+    else:
+        return input.norm()
 class FP16_Optimizer(object):
     """
@@ -115,7 +128,8 @@ class FP16_Optimizer(object):
         Returns -1 if the most recently computed fp16 gradients overflowed
         """
         # TODO: currently using pre-1.0 api, and not most efficient with copy to cpu and sync
-        norm = float(torch.norm(fp16_grads_flat, p=norm_type))
+        # only support 2-norm now
+        norm = float(fused_norm(fp16_grads_flat))
         if norm == float('inf') or norm == -float('inf') or norm != norm:
             return -1
         else:
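The hunk above reports overflow by returning -1 when the (still loss-scaled) gradient norm comes back as inf, -inf, or NaN; NaN is caught by norm != norm, since NaN is the only value not equal to itself. The same predicate as a tiny self-contained sketch, under an illustrative name:

def is_overflow(norm):
    # inf, -inf, or NaN (NaN != NaN) means the fp16 gradients overflowed
    return norm == float('inf') or norm == -float('inf') or norm != norm

print(is_overflow(2.5))            # False
print(is_overflow(float('inf')))   # True
print(is_overflow(float('nan')))   # True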
@@ -140,8 +154,8 @@ class FP16_Optimizer(object):
             return
         # norm is in fact norm*cur_scale
-        self.optimizer.step(grads_group=[[g] for g in grads_groups_flat],
-                            output_params_group=[[p] for p in self.fp16_groups_flat],
+        self.optimizer.step(grads=[[g] for g in grads_groups_flat],
+                            output_params=[[p] for p in self.fp16_groups_flat],
                             scale=self.cur_scale,
                             grad_norms=norm_groups)
...
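For context on the renamed keyword arguments in the step call above: FP16_Optimizer works with one flattened fp16 gradient tensor and one flattened fp16 weight tensor per parameter group (the import of _flatten_dense_tensors at the top of the file suggests these are built by flattening each group) and hands them to FusedAdam.step. A hedged sketch of such a call site, where fp16_groups, cur_scale, and norm_groups are stand-ins for the optimizer's internal state rather than names taken from this diff:

from torch._utils import _flatten_dense_tensors

def fused_adam_step(optimizer, fp16_groups, cur_scale, norm_groups):
    # Flatten each group's fp16 grads and weights, then call FusedAdam.step
    # using the keyword names introduced by this commit.
    grads_groups_flat = [_flatten_dense_tensors([p.grad for p in group])
                         for group in fp16_groups]
    fp16_groups_flat = [_flatten_dense_tensors([p.data for p in group])
                        for group in fp16_groups]
    optimizer.step(grads=[[g] for g in grads_groups_flat],
                   output_params=[[p] for p in fp16_groups_flat],
                   scale=cur_scale,
                   grad_norms=norm_groups)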
@@ -65,7 +65,7 @@ class FusedAdam(torch.optim.Optimizer):
         super(FusedAdam, self).__init__(params, defaults)
         self.eps_mode = 0 if eps_inside_sqrt else 1
-    def step(self, closure=None, grads_group=None, output_params_group=None, scale=1., grad_norms=None):
+    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):
         """Performs a single optimization step.
         Arguments:
@@ -84,18 +84,30 @@ class FusedAdam(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()
-        if grads_group is None:
+        if grads is None:
             grads_group = [None]*len(self.param_groups)
-        if output_params_group is None:
+        # backward compatibility
+        # assuming a list of parameter means single group
+        elif type(grads[0])!=list:
+            grads_group = [grads]
+        else:
+            grads_group = grads
+        if output_params is None:
             output_params_group = [None]*len(self.param_groups)
+        elif type(output_params[0])!=list:
+            output_params_group = [output_params]
+        else:
+            output_params_group = output_params
         if grad_norms is None:
             grad_norms = [None]*len(self.param_groups)
-        for group, grads, output_params, grad_norm in zip(self.param_groups, grads_group, output_params_group, grad_norms):
-            if grads is None:
-                grads = [None]*len(group['params'])
-            if output_params is None:
-                output_params = [None]*len(group['params'])
+        for group, grads_this_group, output_params_this_group, grad_norm in zip(self.param_groups, grads_group, output_params_group, grad_norms):
+            if grads_this_group is None:
+                grads_this_group = [None]*len(group['params'])
+            if output_params_this_group is None:
+                output_params_this_group = [None]*len(group['params'])
             # compute combined scale factor for this group
             combined_scale = scale
@@ -105,7 +117,7 @@ class FusedAdam(torch.optim.Optimizer):
                 if clip > 1:
                     combined_scale = clip * scale
-            for p, grad, output_param in zip(group['params'], grads, output_params):
+            for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group):
                 #note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients
                 if p.grad is None and grad is None:
                     continue
...
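The new argument handling in FusedAdam.step above accepts either None, a flat list of tensors (interpreted as a single parameter group), or a list of lists (one list per group), and normalizes all of them into the list-of-lists form before iterating. The same normalization as a standalone sketch with a small usage demonstration (the helper name and the toy values are illustrative):

def to_group_list(arg, param_groups):
    # None -> one None per group; flat list -> single group; list of lists -> unchanged
    if arg is None:
        return [None] * len(param_groups)
    elif type(arg[0]) != list:
        return [arg]
    else:
        return arg

param_groups = [{'params': ['w1', 'w2']}, {'params': ['w3']}]
print(to_group_list(None, param_groups))              # [None, None]
print(to_group_list(['g1', 'g2'], param_groups))      # [['g1', 'g2']]
print(to_group_list([['g1'], ['g2']], param_groups))  # [['g1'], ['g2']]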