Commit 8599b854 authored by Deyu Fu

converge fused_sgd and sgd code (dtype support, fused kernel, wd_after)

parent adad5996
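
For context, a minimal usage sketch of the converged optimizer introduced by this commit. The model, data, and import path below are hypothetical placeholders (the diff does not show the module path); the constructor arguments wd_after_momentum and set_grad_none are the ones added here.

import torch
from apex.optimizers import FusedSGD   # assumption: the converged class is exported here

model = torch.nn.Linear(1024, 1024).cuda().half()   # placeholder fp16 model
optimizer = FusedSGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4,
                     wd_after_momentum=False,  # new: apply weight decay before the momentum update
                     set_grad_none=True)       # new: zero_grad() sets .grad to None

out = model(torch.randn(8, 1024, device='cuda', dtype=torch.float16))
out.float().sum().backward()
optimizer.step()
optimizer.zero_grad()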
 import torch
 from torch.optim import Optimizer
-from amp_C import multi_tensor_axpby
 from apex.multi_tensor_apply import multi_tensor_applier

 class SGD(Optimizer):
@@ -16,6 +15,8 @@ class SGD(Optimizer):
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
         dampening (float, optional): dampening for momentum (default: 0)
         nesterov (bool, optional): enables Nesterov momentum (default: False)
+        set_grad_none (bool, optional): whether set grad to None when zero_grad()
+            method is called. (default: True)

     Example:
         >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
         >>> optimizer.zero_grad()
@@ -40,7 +41,8 @@ class SGD(Optimizer):
     """

     def __init__(self, params, lr=0.1, momentum=0., dampening=0.,
-                 weight_decay=0., nesterov=False):
+                 weight_decay=0., nesterov=False, wd_after_momentum=False,
+                 set_grad_none=True):
         if lr < 0.0:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if momentum < 0.0:
@@ -50,6 +52,18 @@ class SGD(Optimizer):
         defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                         weight_decay=weight_decay, nesterov=nesterov)
+        self.wd_after_momentum = wd_after_momentum
+        self.set_grad_none = set_grad_none
+
+        if multi_tensor_applier.available:
+            import amp_C
+            # Skip buffer
+            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+            self.multi_tensor_axpby = amp_C.multi_tensor_axpby
+            self.multi_tensor_sgd = amp_C.multi_tensor_sgd
+        else:
+            raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')
+
         if nesterov and (momentum <= 0 or dampening != 0):
             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
         super(SGD, self).__init__(params, defaults)
@@ -60,10 +74,29 @@ class SGD(Optimizer):
             group.setdefault('nesterov', False)

     def zero_grad(self):
-        for group in self.param_groups:
-            for p in group['params']:
-                if p.grad is not None:
-                    p.grad.fill_(0.33)
+        if self.set_grad_none:
+            for group in self.param_groups:
+                for p in group['params']:
+                    p.grad = None
+        else:
+            super(SGD, self).zero_grad()
+
+    def get_momentums(self, params):
+        momentums = []
+        first_run = True
+        for p in params:
+            param_state = self.state[p]
+            # torch.optim.SGD initializes momentum in the main loop, we have
+            # to do it here, and track whether or not we've done so, so that
+            # momentum application can be skipped in the main kernel.
+            if 'momentum_buffer' not in param_state:
+                first_run = True
+                buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
+                momentums.append(buf)
+            else:
+                first_run = False
+                momentums.append(param_state['momentum_buffer'])
+        return momentums, first_run

     def step(self, closure=None):
         """Performs a single optimization step.
@@ -81,59 +114,69 @@ class SGD(Optimizer):
             dampening = group['dampening']
             nesterov = group['nesterov']

-            param_list, grad_list, momentum_list = [], [], []
-            for p in group['params']:
-                if p.grad is None:
-                    continue
-                # create lists for multi tensor apply
-                param_list.append(p.data)
-                grad_list.append(p.grad.data)
-                if momentum != 0:
-                    param_state = self.state[p]
-                    if 'momentum_buffer' not in param_state:
-                        buf = param_state['momentum_buffer'] = torch.clone(p.grad.data).detach()
-                        group['init'] = True
-                    else:
-                        buf = param_state['momentum_buffer']
-                        group['init'] = False
-                    momentum_list.append(buf)
-
-            if weight_decay != 0:
-                multi_tensor_applier(
-                    multi_tensor_axpby,
-                    torch.cuda.IntTensor([0]),#dummy_overflow_buf,
-                    [grad_list, param_list, grad_list],
-                    1.,
-                    weight_decay,
-                    2)
-            if momentum != 0:
-                if not group['init']:
-                    multi_tensor_applier(
-                        multi_tensor_axpby,
-                        torch.cuda.IntTensor([0]),#dummy_overflow_buf,
-                        [momentum_list, grad_list, momentum_list],
-                        momentum,
-                        1.-dampening,
-                        2)
-                if nesterov:
-                    multi_tensor_applier(
-                        multi_tensor_axpby,
-                        torch.cuda.IntTensor([0]),#dummy_overflow_buf,
-                        [grad_list, momentum_list, grad_list],
-                        1.,
-                        momentum,
-                        2)
-                else:
-                    grad_list = momentum_list
-
-            multi_tensor_applier(
-                multi_tensor_axpby,
-                torch.cuda.IntTensor([0]),#dummy_overflow_buf,
-                [param_list, grad_list, param_list],
-                1.,
-                -group['lr'],
-                2)
+            for group_dtype in [torch.float16, torch.float32]:
+                grad_list = [p.grad for p in group['params'] if (p.dtype == group_dtype and p.grad is not None)]
+                if len(grad_list) == 0:
+                    continue
+                param_list = [p for p in group['params'] if (p.dtype == group_dtype and p.grad is not None)]
+
+                if momentum != 0:
+                    momentum_list, first_run = self.get_momentums(param_list)
+                    multi_tensor_applier(
+                        self.multi_tensor_sgd,
+                        self._dummy_overflow_buf,
+                        [grad_list, param_list, momentum_list],
+                        weight_decay,
+                        momentum,
+                        dampening,
+                        group['lr'],
+                        nesterov,
+                        first_run,
+                        self.wd_after_momentum,
+                        1.0)
+                else:
+                    # show how to implement SGD using axpby, without writing new multi_tensor kernel
+                    # only enabled now in no momentum case, since it saves creating momentum for us
+                    # keep momentum != 0 code below for completeness
+                    if weight_decay != 0 and not self.wd_after_momentum:
+                        multi_tensor_applier(
+                            self.multi_tensor_axpby,
+                            self._dummy_overflow_buf,
+                            [grad_list, param_list, grad_list],
+                            1.,
+                            weight_decay,
+                            2)
+                    if momentum != 0: # always False
+                        if not first_run:
+                            multi_tensor_applier(
+                                self.multi_tensor_axpby,
+                                self._dummy_overflow_buf,
+                                [momentum_list, grad_list, momentum_list],
+                                momentum,
+                                1.-dampening,
+                                2)
+                        if nesterov:
+                            multi_tensor_applier(
+                                self.multi_tensor_axpby,
+                                self._dummy_overflow_buf,
+                                [grad_list, momentum_list, grad_list],
+                                1.,
+                                momentum,
+                                2)
+                        else:
+                            grad_list = momentum_list
+                    if weight_decay != 0 and self.wd_after_momentum:
+                        multi_tensor_applier(
+                            self.multi_tensor_axpby,
+                            self._dummy_overflow_buf,
+                            [grad_list, param_list, grad_list],
+                            1.,
+                            weight_decay,
+                            2)
+                    multi_tensor_applier(
+                        self.multi_tensor_axpby,
+                        self._dummy_overflow_buf,
+                        [param_list, grad_list, param_list],
+                        1.,
+                        -group['lr'],
+                        2)

         return loss
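
For reference, an unfused per-tensor sketch of the update that the multi_tensor_sgd / multi_tensor_axpby calls above perform; multi_tensor_axpby writes out = a*x + b*y over lists of tensors, and the arguments below mirror the calls in the diff. This follows torch.optim.SGD semantics; the fused kernel's exact first_run handling is an assumption, since it is not shown in this commit.

import torch

def sgd_step_per_tensor(p, buf, lr, weight_decay, momentum, dampening,
                        nesterov, first_run, wd_after_momentum):
    # Unfused sketch of one SGD update on a single (param, grad, momentum buffer)
    # triple, following the axpby pattern above: out = a*x + b*y.
    g = p.grad.data
    if weight_decay != 0 and not wd_after_momentum:
        g = g.add(p.data, alpha=weight_decay)               # g = 1*g + wd*p
    if momentum != 0:
        if first_run:
            buf.copy_(g)                                     # buffer seeded from (wd-adjusted) grad
        else:
            buf.mul_(momentum).add_(g, alpha=1 - dampening)  # buf = momentum*buf + (1-dampening)*g
        g = g.add(buf, alpha=momentum) if nesterov else buf  # nesterov: g = g + momentum*buf
    if weight_decay != 0 and wd_after_momentum:
        g = g.add(p.data, alpha=weight_decay)                # g = 1*g + wd*p
    p.data.add_(g, alpha=-lr)                                # p = 1*p - lr*g

# Example on a single fp32 parameter with a fresh (zero) momentum buffer:
p = torch.nn.Parameter(torch.randn(4))
p.grad = torch.randn(4)
buf = torch.zeros_like(p.data)
sgd_step_per_tensor(p, buf, lr=0.1, weight_decay=1e-4, momentum=0.9,
                    dampening=0.0, nesterov=False, first_run=True,
                    wd_after_momentum=False)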