Commit 8599b854 authored by Deyu Fu

Converge fused_sgd and sgd code (dtype support, fused kernel, wd_after_momentum)

parent adad5996
import torch
from torch.optim import Optimizer

from apex.multi_tensor_apply import multi_tensor_applier

class SGD(Optimizer):
@@ -16,6 +15,8 @@ class SGD(Optimizer):
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        set_grad_none (bool, optional): whether to set grads to None when the
            zero_grad() method is called. (default: True)

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
@@ -40,7 +41,8 @@ class SGD(Optimizer):
    """
    def __init__(self, params, lr=0.1, momentum=0., dampening=0.,
                 weight_decay=0., nesterov=False, wd_after_momentum=False,
                 set_grad_none=True):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
@@ -50,6 +52,18 @@ class SGD(Optimizer):
        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        self.wd_after_momentum = wd_after_momentum
        self.set_grad_none = set_grad_none

        if multi_tensor_applier.available:
            import amp_C
            # Skip buffer: dummy overflow flag expected by multi_tensor_applier
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
            self.multi_tensor_axpby = amp_C.multi_tensor_axpby
            self.multi_tensor_sgd = amp_C.multi_tensor_sgd
        else:
            raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')

        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)
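A minimal usage sketch, assuming this class is exported as `apex.optimizers.FusedSGD` (as the error message above suggests) and that apex was built with its CUDA extensions; the last two keyword arguments are the ones this commit adds:

```python
import torch
from apex.optimizers import FusedSGD   # assumed import path for this class

model = torch.nn.Linear(32, 8).cuda()
optimizer = FusedSGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4,
                     wd_after_momentum=True,   # apply weight decay after the momentum update
                     set_grad_none=True)       # zero_grad() drops grads instead of zeroing them
```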
@@ -60,10 +74,29 @@ class SGD(Optimizer):
            group.setdefault('nesterov', False)
    def zero_grad(self):
        if self.set_grad_none:
            for group in self.param_groups:
                for p in group['params']:
                    p.grad = None
        else:
            super(SGD, self).zero_grad()
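A small, hypothetical illustration of what the flag changes (the tensors below are stand-ins, not part of the diff):

```python
import torch

p = torch.nn.Parameter(torch.ones(4))
p.grad = torch.full_like(p, 2.0)

# set_grad_none=True: the gradient buffer is released and re-allocated by the
# next backward(), which can lower peak memory between iterations.
p.grad = None

# set_grad_none=False: falls back to the stock Optimizer.zero_grad(), which
# keeps the allocation and overwrites it with zeros in place.
p.grad = torch.full_like(p, 2.0)
p.grad.zero_()
```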
    def get_momentums(self, params):
        momentums = []
        first_run = True
        for p in params:
            param_state = self.state[p]
            # torch.optim.SGD initializes momentum in the main loop, we have
            # to do it here, and track whether or not we've done so, so that
            # momentum application can be skipped in the main kernel.
            if 'momentum_buffer' not in param_state:
                first_run = True
                buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                momentums.append(buf)
            else:
                first_run = False
                momentums.append(param_state['momentum_buffer'])
        return momentums, first_run
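For comparison, a sketch of the per-parameter lazy initialization this batches up (loosely mirroring `torch.optim.SGD`); the names here are illustrative, the point being that the fused path creates zero buffers up front and hands the kernel a `first_run` flag so it can skip blending into them:

```python
# Illustrative per-parameter counterpart of get_momentums (not part of the diff).
import torch

def momentum_buffer_for(state, p):
    if 'momentum_buffer' not in state[p]:
        buf = state[p]['momentum_buffer'] = torch.zeros_like(p.data)
        return buf, True     # first use: do not blend stale momentum
    return state[p]['momentum_buffer'], False
```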
    def step(self, closure=None):
        """Performs a single optimization step.
@@ -81,59 +114,69 @@ class SGD(Optimizer):
            dampening = group['dampening']
            nesterov = group['nesterov']

            # create per-dtype lists for multi tensor apply
            for group_dtype in [torch.float16, torch.float32]:
                grad_list = [p.grad for p in group['params'] if (p.dtype == group_dtype and p.grad is not None)]
                if len(grad_list) == 0:
                    continue
                param_list = [p for p in group['params'] if (p.dtype == group_dtype and p.grad is not None)]
                if momentum != 0:
                    momentum_list, first_run = self.get_momentums(param_list)
                    # Fused path: weight decay, momentum and the parameter update
                    # happen in a single multi_tensor_sgd kernel launch.
                    multi_tensor_applier(
                        self.multi_tensor_sgd,
                        self._dummy_overflow_buf,
                        [grad_list, param_list, momentum_list],
                        weight_decay,
                        momentum,
                        dampening,
                        group['lr'],
                        nesterov,
                        first_run,
                        self.wd_after_momentum,
                        1.0)
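A hedged, eager-mode reading of the per-tensor arithmetic a single `multi_tensor_sgd` launch is expected to perform, inferred from the argument list above rather than from the kernel source; the trailing `1.0` is assumed to be a gradient scale, and mixed fp16/fp32 handling is ignored:

```python
# Reference sketch only -- my reading of the fused kernel's arguments, not its source.
import torch

def sgd_reference(param, grad, buf, lr, momentum, dampening, weight_decay,
                  nesterov, first_run, wd_after_momentum, scale=1.0):
    grad = grad * scale
    if weight_decay != 0 and not wd_after_momentum:
        grad = grad + weight_decay * param.data          # classic L2 before momentum
    if momentum != 0:
        if first_run:
            buf.copy_(grad)                              # fresh buffer: skip blending
        else:
            buf.mul_(momentum).add_(grad, alpha=1 - dampening)
        update = grad + momentum * buf if nesterov else buf
    else:
        update = grad
    if weight_decay != 0 and wd_after_momentum:
        update = update + weight_decay * param.data      # wd_after_momentum ordering
    param.data.add_(update, alpha=-lr)
```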
                else:
                    # show how to implement SGD using axpby, without writing a new multi_tensor kernel.
                    # Only enabled in the no-momentum case for now, since that saves creating momentum
                    # buffers for us; the momentum != 0 code below is kept for completeness.
                    if weight_decay != 0 and not self.wd_after_momentum:
                        multi_tensor_applier(
                            self.multi_tensor_axpby,
                            self._dummy_overflow_buf,
                            [grad_list, param_list, grad_list],
                            1.,
                            weight_decay,
                            2)
                    if momentum != 0:  # always False on this path
                        if not first_run:
                            multi_tensor_applier(
                                self.multi_tensor_axpby,
                                self._dummy_overflow_buf,
                                [momentum_list, grad_list, momentum_list],
                                momentum,
                                1. - dampening,
                                2)
                        if nesterov:
                            multi_tensor_applier(
                                self.multi_tensor_axpby,
                                self._dummy_overflow_buf,
                                [grad_list, momentum_list, grad_list],
                                1.,
                                momentum,
                                2)
                        else:
                            grad_list = momentum_list
                    if weight_decay != 0 and self.wd_after_momentum:
                        multi_tensor_applier(
                            self.multi_tensor_axpby,
                            self._dummy_overflow_buf,
                            [grad_list, param_list, grad_list],
                            1.,
                            weight_decay,
                            2)
                    multi_tensor_applier(
                        self.multi_tensor_axpby,
                        self._dummy_overflow_buf,
                        [param_list, grad_list, param_list],
                        1.,
                        -group['lr'],
                        2)
        return loss
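The `momentum == 0` path above composes the same update out of generic axpby launches instead of the dedicated kernel; each call computes `out = a*x + b*y` over a pair of tensor lists (the trailing integer argument to the applier is not modeled here). A plain-PyTorch restatement of that composition, illustrative only:

```python
# Illustrative restatement of the axpby-based, no-momentum path (single tensors
# stand in for the tensor lists handled by multi_tensor_applier).
import torch

def axpby(a, x, b, y):
    return a * x + b * y

def sgd_no_momentum_step(param, grad, lr, weight_decay):
    if weight_decay != 0:
        # grad = 1 * grad + weight_decay * param
        # (same result whichever side of the absent momentum stage it runs on)
        grad = axpby(1., grad, weight_decay, param.data)
    # param = 1 * param + (-lr) * grad
    param.data.copy_(axpby(1., param.data, -lr, grad))
```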