"examples/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "70d7486de5e7a52a4744244d023b866d46252889"
Commit cadad920 authored by Simon Layton

Fused multi-tensor SGD

Initial implementation, all fp32
Tested against torch.optim.sgd
parent 40555b3a
import torch
from torch.optim.optimizer import Optimizer, required
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
class SGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et. al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            v = \rho * v + g \\
            p = p - lr * v

        where p, g, v and :math:`\rho` denote the parameters, gradient,
        velocity, and momentum respectively.

        This is in contrast to Sutskever et. al. and
        other frameworks which employ an update of the form

        .. math::
            v = \rho * v + lr * g \\
            p = p - v

        The Nesterov version is analogously modified.
    """
    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

        # Skip buffer: multi_tensor_applier's no-op flag. The fused kernel
        # skips its work when this is set to 1 (e.g. after an overflow).
        self._dummy_overflow_buf = torch.cuda.IntTensor([0])

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            # Only parameters that actually received gradients take part in the fused update.
            params = [p for p in group['params'] if p.grad is not None]
            grads = [p.grad for p in params]
            momentums = []
            first_run = False
            for p in params:
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    first_run = True
                    buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                    buf.mul_(momentum).add_(p.grad.data)
                    momentums.append(buf)
                else:
                    first_run = False
                    momentums.append(param_state['momentum_buffer'])

            # launch update using multi tensor apply
            multi_tensor_applier(
                amp_C.multi_tensor_sgd,
                self._dummy_overflow_buf,
                [grads, params, momentums],
                weight_decay,
                momentum,
                dampening,
                group['lr'],
                nesterov,
                first_run)

        return loss
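The commit message says the implementation was tested against torch.optim.sgd. A minimal parity check in that spirit might look like the sketch below; the import path, model, step count, and tolerance are illustrative assumptions, not the commit's actual test harness:

import copy
import torch
from apex.optimizers.sgd import SGD as FusedSGD  # assumed import path for the class above

torch.manual_seed(0)
ref_model = torch.nn.Linear(64, 64).cuda()
test_model = copy.deepcopy(ref_model)

# Same hyperparameters on both optimizers; all tensors are fp32.
kwargs = dict(lr=0.1, momentum=0.9, dampening=0.0, weight_decay=1e-4)
ref_opt = torch.optim.SGD(ref_model.parameters(), **kwargs)
test_opt = FusedSGD(test_model.parameters(), **kwargs)

for _ in range(5):
    x = torch.randn(8, 64, device='cuda')
    for model, opt in ((ref_model, ref_opt), (test_model, test_opt)):
        opt.zero_grad()
        model(x).sum().backward()
        opt.step()

for p_ref, p_test in zip(ref_model.parameters(), test_model.parameters()):
    assert torch.allclose(p_ref, p_test, atol=1e-5), "fused SGD diverged from torch.optim.SGD"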
csrc/amp_C_frontend.cpp

@@ -6,6 +6,17 @@ void multi_tensor_scale_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   float scale);
 
+void multi_tensor_sgd_cuda(
+  int chunk_size,
+  at::Tensor noop_flag,
+  std::vector<std::vector<at::Tensor>> tensor_lists,
+  float wd,
+  float momentum,
+  float dampening,
+  float lr,
+  bool nesterov,
+  bool first_run);
+
 void scale_check_overflow_cuda(
   const at::Tensor& grads,
   float scale,
@@ -37,4 +48,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("scale_check_overflow", &scale_check_overflow, "Fused overflow check + scale for FP32 tensors");
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
         "Fused overflow check + scale for a list of contiguous tensors");
+  m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
+        "Fused SGD optimizer for list of contiguous tensors");
 }
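In normal use this binding is reached through multi_tensor_applier, which supplies the chunk size and the overflow/no-op flag. For completeness, a direct call would follow the argument order of the declaration above; the chunk size below is an illustrative value, not one mandated by the kernel:

import torch
import amp_C

noop_flag = torch.cuda.IntTensor([0])   # 0 = run; the kernel early-exits when this is 1
w = torch.randn(1024, device='cuda')    # weights
g = torch.randn_like(w)                 # gradients
m = torch.zeros_like(w)                 # momentum buffer

# Argument order mirrors multi_tensor_sgd_cuda:
# chunk_size, noop_flag, [[grads], [weights], [momentums]],
# wd, momentum, dampening, lr, nesterov, first_run
amp_C.multi_tensor_sgd(
    2048 * 32,            # chunk size (illustrative)
    noop_flag,
    [[g], [w], [m]],
    1e-4,                 # weight decay
    0.9,                  # momentum
    0.0,                  # dampening
    0.1,                  # lr
    False,                # nesterov
    True)                 # first_run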
csrc/multi_tensor_sgd_kernel.cu

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "multi_tensor_apply.cuh"
#include <assert.h>
#include <cuda_runtime.h>
#define BLOCK_SIZE 512
#define ILP 4
/**
* Perform fused SGD on multiple buffers
* tl[0] : gradients
* tl[1] : weights
* tl[2] : momentum buffers
* wd : weight_decay (scalar)
* momentum : momentum (scalar)
* dampening : momentum dampening (scalar)
* lr : learning rate (scalar)
* nesterov : enable nesterov (bool)
* first run : necessary for proper momentum handling & init
**/
template<typename T>
struct SGDFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorList<3>& tl,
    float wd,
    float momentum,
    float dampening,
    float lr,
    bool nesterov,
    bool first_run)
  {
    __shared__ int noop_smem;

    if(threadIdx.x == 0)
      noop_smem = *noop_gmem;
    __syncthreads();
    if(noop_smem == 1)
      return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    T* grad_in = (T*)tl.addresses[0][tensor_loc];
    grad_in += chunk_idx*chunk_size;

    T* weight_in = (T*)tl.addresses[1][tensor_loc];
    weight_in += chunk_idx*chunk_size;

    T* mom_in = (T*)tl.addresses[2][tensor_loc];
    mom_in += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    // Non-divergent exit condition for the __syncthreads
    float incoming_grads[ILP];
    float incoming_weights[ILP];
    float incoming_moms[ILP];
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        incoming_grads[ii] = 0;
        incoming_weights[ii] = 0;
        incoming_moms[ii] = 0;
        int i = i_start + threadIdx.x + ii*blockDim.x;
        // Guard all three loads so out-of-range threads never read past the chunk.
        if(i < n && i < chunk_size)
        {
          incoming_grads[ii] = static_cast<float>(grad_in[i]);
          incoming_weights[ii] = static_cast<float>(weight_in[i]);
          incoming_moms[ii] = static_cast<float>(mom_in[i]);
        }
      }

      // note for clarification to future michael:
      // From a pure memory dependency perspective, there's likely no point unrolling
      // the write loop, since writes just fire off once their LDGs arrive.
      // Put another way, the STGs are dependent on the LDGs, but not on each other.
      // There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size) {
          // apply weight decay
          if (wd != 0.f) {
            incoming_grads[ii] += wd * incoming_weights[ii];
          }
          if (momentum != 0.f) {
            if (!first_run) {
              incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
            } else {
              // First run: initialize the buffer from the (decayed) gradient,
              // matching torch.optim.SGD's first-step behaviour.
              incoming_moms[ii] = incoming_grads[ii];
            }
            if (nesterov) {
              incoming_grads[ii] += momentum * incoming_moms[ii];
            } else {
              // Plain momentum: the step follows the momentum buffer,
              // i.e. p = p - lr * v as stated in the docstring.
              incoming_grads[ii] = incoming_moms[ii];
            }
          }
          // adjust the weight and write out
          weight_in[i] += (-lr * incoming_grads[ii]);
          // also write out the new momentum
          if (momentum != 0.f) {
            mom_in[i] = incoming_moms[ii];
          }
        }
      }

      // *noop_gmem = 1 is NOT guaranteed to be seen immediately by thread 0. I wonder if
      // we can rig block-wide and grid-wide short-circuiting with only one syncthreads.
      // It's possible we can just lean on the cache (no smem or syncs) and still be fast.
      if(threadIdx.x == 0)
        noop_smem = *noop_gmem;
      __syncthreads();
      if(noop_smem == 1)
        break;
    }
  }
};
void multi_tensor_sgd_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float wd,
  float momentum,
  float dampening,
  float lr,
  bool nesterov,
  bool first_run)
{
  multi_tensor_apply<3>(
    BLOCK_SIZE,
    chunk_size,
    noop_flag,
    tensor_lists,
    SGDFunctor<float>(),
    wd,
    momentum,
    dampening,
    lr,
    nesterov,
    first_run);

  AT_CUDA_CHECK(cudaGetLastError());

  // AT_CUDA_CHECK(cudaDeviceSynchronize());
}
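For readers following the fp32 math, the per-element update performed by SGDFunctor (weight decay, then the momentum/dampening update, then the Nesterov or plain-momentum step) can be restated as a small Python reference; this is offered purely as a readability aid and is not part of the commit:

def sgd_reference_update(grad, weight, mom,
                         wd, momentum, dampening, lr, nesterov, first_run):
    """Scalar restatement of SGDFunctor's inner loop (all fp32 math)."""
    if wd != 0.0:
        grad += wd * weight                # fold L2 / weight decay into the gradient
    if momentum != 0.0:
        if not first_run:
            mom = mom * momentum + (1.0 - dampening) * grad
        else:
            mom = grad                     # first step: buffer starts from the gradient
        if nesterov:
            grad += momentum * mom         # Nesterov look-ahead
        else:
            grad = mom                     # plain momentum: step along the buffer
    weight += -lr * grad                   # p = p - lr * v
    return weight, mom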
setup.py

@@ -49,7 +49,8 @@ if "--cuda_ext" in sys.argv:
             CUDAExtension(name='amp_C',
                           sources=['csrc/amp_C_frontend.cpp',
                                    'csrc/scale_check_overflow_kernel.cu',
-                                   'csrc/multi_tensor_scale_kernel.cu'],
+                                   'csrc/multi_tensor_scale_kernel.cu',
+                                   'csrc/multi_tensor_sgd_kernel.cu'],
                           extra_compile_args={'cxx': ['-O3'],
                                               'nvcc':['-lineinfo',
                                                       '-O3',
@@ -73,6 +74,7 @@ if "--cuda_ext" in sys.argv:
                                               'nvcc':['-maxrregcount=50',
                                                       '-O3',
                                                       '--use_fast_math'] + version_ge_1_1}))
+print(ext_modules)
 
 setup(
     name='apex',
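After rebuilding the extension with the --cuda_ext flag that this block checks for (e.g. python setup.py install --cuda_ext), a quick sanity check, offered as a suggested workflow rather than part of the commit, is that the new symbol shows up on the rebuilt module:

import amp_C
assert hasattr(amp_C, "multi_tensor_sgd"), "amp_C was built without the fused SGD kernel"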