Commit e519c1e3 authored by luise.chen, committed by flyingdown

Add FusedLARS optimizer (#109)

* Add fused_lars optimizer

* Update the primitive fused_lars optimizer; works for ResNet-50 with NHWC and NCHW

* Add the Nesterov momentum flow to FusedLARS
parent 3d72ea06
@@ -4,3 +4,4 @@ from .fused_novograd import FusedNovoGrad
from .fused_lamb import FusedLAMB
from .fused_adagrad import FusedAdagrad
from .fused_mixed_precision_lamb import FusedMixedPrecisionLamb
from .fused_lars import FusedLARS
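For context, a minimal usage sketch of the newly exported optimizer (editor-added illustration, not part of the commit; it assumes apex was built with the CUDA extensions so that amp_C is importable, that a CUDA device is available, and that the model and hyperparameters are placeholders):

# Hypothetical end-to-end example; model, data, and hyperparameters are placeholders.
import torch
from apex.optimizers import FusedLARS

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 32 * 32, 10),
).cuda()

opt = FusedLARS(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-5,
                trust_coefficient=0.001, eps=1e-8, nesterov=False)

x = torch.randn(4, 3, 32, 32, device="cuda")
y = torch.randint(0, 10, (4,), device="cuda")
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()

The new optimizer module, apex/optimizers/fused_lars.py, follows.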
import torch
from torch.optim.optimizer import Optimizer, required
from torch import nn
from torch.nn.parameter import Parameter
from apex.multi_tensor_apply import multi_tensor_applier
class FusedLARS(Optimizer):
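    """LARS (layer-wise adaptive rate scaling) with a fused multi-tensor CUDA kernel.

    Editor's note -- summary of the update rule implemented by the
    multi_tensor_lars kernel added in this commit: for each parameter tensor w
    with gradient g and momentum buffer m,

        trust_ratio = trust_coefficient * ||w|| / (||g|| + weight_decay * ||w|| + eps)
        scaled_lr   = lr * trust_ratio   (plain lr when the step is flagged as skipped;
                                          trust_ratio falls back to 1.0 if either norm is zero)
        g           = g + weight_decay * w
        m           = momentum * m - scaled_lr * g
        w           = w + m              (Nesterov: w + momentum * m - scaled_lr * g)

    Requires apex to be built with the CUDA extensions (--cuda_ext), which
    provide amp_C.multi_tensor_l2norm and amp_C.multi_tensor_lars.
    """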
def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, trust_coefficient=0.001, eps=0.0,
nesterov=False, wd_after_momentum=False,
materialize_master_grads=True, set_grad_none=False):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
weight_decay=weight_decay, nesterov=nesterov, trust_coefficient=trust_coefficient, eps=eps, is_skipped=False)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(FusedLARS, self).__init__(params, defaults)
self.wd_after_momentum = wd_after_momentum
self.materialize_master_grads = materialize_master_grads
self.most_recent_scale = 1.0
self.scale_set_by_backward = False
self.set_grad_none = set_grad_none
self.trust_coefficient = trust_coefficient
self.eps = eps
if multi_tensor_applier.available:
import amp_C
            # Skip buffer (overflow flag) shared by the multi-tensor launches
            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
            self.multi_tensor_lars = amp_C.multi_tensor_lars
else:
raise RuntimeError('apex.optimizers.FusedLARS requires cuda extensions')
def __setstate__(self, state):
super(FusedLARS, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('nesterov', False)
def zero_grad(self):
if self.set_grad_none:
for group in self.param_groups:
for p in group['params']:
p.grad = None
else:
super(FusedLARS, self).zero_grad()
def get_momentums(self, params):
momentums = []
first_run = True
for p in params:
if p.grad is None:
continue
param_state = self.state[p]
d_p = p.grad.data
# torch.optim.SGD initializes momentum in the main loop, we have
# to do it here, and track whether or not we've done so, so that
# momentum application can be skipped in the main kernel.
if 'momentum_buffer' not in param_state:
first_run = True
buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
momentums.append(buf)
else:
first_run = False
momentums.append(param_state['momentum_buffer'])
return momentums, first_run
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
        # NOTE: explicit master-weight handling via the amp stash is detected here but
        # then force-disabled, so step() always takes the plain fp16/fp32 grouping path.
        explicit_master_params = (hasattr(self, "_amp_stash") and
                                  hasattr(self._amp_stash, "fp32_from_fp16_groups"))
        explicit_master_params = False
for gid, group in enumerate(self.param_groups):
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']
lr = group['lr']
is_skipped = group['is_skipped']
# For each group, there are 3 possible combinations we need to consider:
# grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy
# 1. fp16, fp16, fp16, No
# 2. fp32, fp32, fp32, No
# 3. fp16, fp32, fp32, Yes
first_runs = [True, True]
g_norms_grp = []
w_norms_grp = []
# I think a bit of code divergence in exchange for naming clarity is worthwhile
if explicit_master_params:
print('explicit_master_params')
stash = self._amp_stash
fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
if self.materialize_master_grads:
fp16_model_params = [p for i, p in enumerate(
stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None]
fp32_from_fp16_grads = [p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
fp16_set = [fp32_from_fp16_grads, fp32_from_fp16_params,
fp32_from_fp16_momentums, fp16_model_params]
else:
fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None]
fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None]
fp32_from_fp16_params = [p for i, p in enumerate(
stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None]
fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
fp16_set = [fp16_model_grads, fp32_from_fp16_params,
fp32_from_fp16_momentums, fp16_model_params]
launch_sets= [fp16_set, [fp32_grads, fp32_params, fp32_momentums]]
else:
fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
                fp16_grads = []
                for p in fp16_params:
                    # Match the gradient's memory format to the parameter's
                    # (contiguous/NCHW or channels_last/NHWC).
                    if p.is_contiguous():
                        fp16_grads.append(p.grad)
                    elif p.is_contiguous(memory_format=torch.channels_last):
                        fp16_grads.append(p.grad.to(memory_format=torch.channels_last))
fp16_momentums, first_runs[0] = self.get_momentums(fp16_params)
# Compute L2 norms
if len(fp16_params) > 0:
w_norms = multi_tensor_applier(
self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[[p.data for p in fp16_params]],
True)[1]
g_norms = multi_tensor_applier(
self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[[p.data for p in fp16_grads]],
True)[1]
else:
w_norms = []
g_norms = []
w_norms_grp.append(w_norms)
g_norms_grp.append(g_norms)
fp32_params = [p for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
fp32_grads = []
for p in fp32_params:
if p.is_contiguous():
fp32_grads.append(p.grad)
elif p.is_contiguous(memory_format=torch.channels_last):
fp32_grads.append(p.grad.to(memory_format=torch.channels_last))
fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
# Compute L2 norms
if len(fp32_params) > 0:
w_norms = multi_tensor_applier(
self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[[p.data for p in fp32_params]],
True)[1]
g_norms = multi_tensor_applier(
self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[[p.data for p in fp32_grads]],
True)[1]
else:
w_norms = []
g_norms = []
w_norms_grp.append(w_norms)
g_norms_grp.append(g_norms)
launch_sets = [[fp16_grads, fp16_params, fp16_momentums],
[fp32_grads, fp32_params, fp32_momentums]]
for s, (launch_set, first_run, g_norms, w_norms) in enumerate(zip(launch_sets, first_runs, g_norms_grp, w_norms_grp)):
assert len(launch_set[0]) == len(launch_set[1])
assert len(launch_set[0]) == len(launch_set[2])
if len(launch_set[0]) > 0:
multi_tensor_applier(
self.multi_tensor_lars,
self._dummy_overflow_buf,
launch_set,
g_norms,
w_norms,
group['lr'],
group['trust_coefficient'],
self.eps,
weight_decay,
momentum,
dampening,
nesterov,
first_run,
self.wd_after_momentum,
1.0/self.most_recent_scale,
group['is_skipped'])
self.most_recent_scale = 1.0
self.scale_set_by_backward = False
return loss
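To make the fused update easier to audit, here is an unfused, single-tensor reference of the same arithmetic (an editor-added sketch derived from the LARSFunctor kernel further below; it is not part of the commit, and like the kernel's inner loop it applies weight decay before momentum and does not use dampening):

import torch

def lars_step_reference(w, g, m, lr, momentum, weight_decay,
                        trust_coefficient=0.001, eps=0.0, nesterov=False):
    """Unfused reference of one LARS update for a single tensor (in-place on w and m)."""
    w_norm = w.norm().item()
    g_norm = g.norm().item()
    if w_norm > 0.0 and g_norm > 0.0:
        trust_ratio = trust_coefficient * w_norm / (g_norm + w_norm * weight_decay + eps)
    else:
        trust_ratio = 1.0
    scaled_lr = lr * trust_ratio
    g = g + weight_decay * w                     # weight decay folded into the gradient
    m.mul_(momentum).add_(g, alpha=-scaled_lr)   # m = momentum * m - scaled_lr * g
    if nesterov:
        w.add_(m * momentum - scaled_lr * g)     # matches the kernel's Nesterov branch
    else:
        w.add_(m)
    return w, m

The next hunks register the new kernel with the amp_C extension.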
@@ -144,6 +144,24 @@ void multi_tensor_lamb_mp_cuda(
  at::Tensor found_inf,
  at::Tensor inv_scale);
void multi_tensor_lars_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor grad_norms,
at::Tensor param_norms,
float lr,
float trust_coefficient,
float epsilon,
float weight_decay,
float momentum,
float dampening,
bool nesterov,
bool first_run,
bool wd_after_momentum,
float scale,
const bool is_skipped);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors");
@@ -171,4 +189,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        "Computes and apply update for LAMB optimizer");
  m.def("multi_tensor_lamb_mp", &multi_tensor_lamb_mp_cuda,
        "Computes and apply update for LAMB optimizer");
m.def("multi_tensor_lars", &multi_tensor_lars_cuda,
"Fused LARS optimizer for list of contiguous tensors");
}
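For reference, FusedLARS.step() obtains per-tensor L2 norms from amp_C.multi_tensor_l2norm (which returns a (total_norm, per_tensor_norms) pair when the per_tensor flag is true) and feeds them to this binding. A stripped-down sketch of the raw call sequence, editor-added with placeholder tensors and hyperparameters, assuming the extension is built:

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

overflow_buf = torch.zeros(1, dtype=torch.int, device="cuda")
params    = [torch.randn(64, 64, device="cuda")]   # placeholder fp32 weights
grads     = [torch.randn_like(params[0])]          # placeholder gradients
momentums = [torch.zeros_like(params[0])]

# Second return value = per-tensor norms; these become grad_norms/param_norms below.
w_norms = multi_tensor_applier(amp_C.multi_tensor_l2norm, overflow_buf, [params], True)[1]
g_norms = multi_tensor_applier(amp_C.multi_tensor_l2norm, overflow_buf, [grads], True)[1]

multi_tensor_applier(amp_C.multi_tensor_lars, overflow_buf,
                     [grads, params, momentums],   # tl[0..2]; tl[3] would be fp16 copies
                     g_norms, w_norms,
                     0.1,      # lr
                     0.001,    # trust_coefficient
                     1e-8,     # epsilon
                     5e-5,     # weight_decay
                     0.9,      # momentum
                     0.0,      # dampening
                     False,    # nesterov
                     True,     # first_run
                     False,    # wd_after_momentum
                     1.0,      # scale (reciprocal of the loss scale)
                     False)    # is_skipped

The new CUDA source, csrc/multi_tensor_lars.cu, follows.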
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "type_shim.h"
#include "compat.h"
#include "multi_tensor_apply.cuh"
#include <assert.h>
#include <cuda_runtime.h>
#define BLOCK_SIZE 512
#define ILP 4
/**
 * Perform a fused LARS update on multiple buffers
 * (an SGD-with-momentum step whose learning rate is scaled by a per-tensor trust ratio)
 * N: number of tensors
 * tl[0] : gradients
 * tl[1] : weights
 * tl[2] : momentum buffers
 * tl[3] : fp16 weights (if appropriate)
 * grad_norms / param_norms : per-tensor L2 norms used to form the trust ratio
 * lr : learning rate (scalar)
 * trust_coefficient : LARS trust coefficient (scalar)
 * epsilon : small constant added to the trust-ratio denominator
 * wd : weight_decay (scalar)
 * momentum : momentum (scalar)
 * dampening : momentum dampening (scalar)
 * nesterov : enable nesterov (bool)
 * first_run : necessary for proper momentum handling & init
 * wd_after_momentum : apply weight decay _after_ momentum instead of before
 * scale : reciprocal of the loss scale passed from the Python wrapper
 * is_skipped : if set, bypass the trust-ratio scaling and use lr directly
 **/
template<int N, typename T_grad, typename T_weight>
struct LARSFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<N>& tl,
float *grad_norms,
float *param_norms,
float lr,
float trust_coefficient,
float epsilon,
float weight_decay,
float momentum,
float dampening,
bool nesterov,
bool first_run,
bool wd_after_momentum,
float scale,
const bool is_skipped) {
// Early exit if we don't need to do anything
if (*noop_gmem) return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
n -= chunk_idx * chunk_size;
//n = min(n, chunk_size);
T_grad* grad_in = (T_grad*) tl.addresses[0][tensor_loc];
grad_in += chunk_idx * chunk_size;
T_weight* weight_in = (T_weight*) tl.addresses[1][tensor_loc];
weight_in += chunk_idx * chunk_size;
T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc];
mom_in += chunk_idx*chunk_size;
at::Half *model_weights_out = nullptr;
if(N == 4)
{
model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
model_weights_out += chunk_idx*chunk_size;
}
float scaled_lr;
if (is_skipped) {
scaled_lr = lr;
}
else {
int tensor_offset = tl.start_tensor_this_launch + tensor_loc;
float p_norm = param_norms[tensor_offset];
float trust_ratio = 1.0;
float g_norm = grad_norms[tensor_offset];
if (g_norm > 0.0f && p_norm > 0.0f) {
trust_ratio = trust_coefficient * p_norm / (g_norm + p_norm * weight_decay + epsilon);
}
scaled_lr = lr * trust_ratio;
}
    // Per-thread register buffers for an ILP-wide slice of grads, weights, and momentum
    float incoming_grads[ILP];
    float incoming_weights[ILP];
    float incoming_moms[ILP];
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
incoming_grads[ii] = 0;
incoming_weights[ii] = 0;
incoming_moms[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
incoming_grads[ii] = static_cast<float>(grad_in[i]);
incoming_weights[ii] = static_cast<float>(weight_in[i]);
incoming_moms[ii] = static_cast<float>(mom_in[i]);
}
}
      // Note: from a pure memory-dependency perspective, there's likely no point unrolling
      // the write loop, since writes just fire off once their LDGs arrive.
      // Put another way, the STGs are dependent on the LDGs, but not on each other.
      // There is still compute ILP benefit from unrolling the loop, though.
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
// apply weight decay before momentum
incoming_grads[ii] += weight_decay * incoming_weights[ii];
incoming_moms[ii] = incoming_moms[ii] * momentum - scaled_lr * incoming_grads[ii];
// adjust the weight and write out
if (nesterov) {
incoming_weights[ii] += incoming_moms[ii] * momentum - scaled_lr * incoming_grads[ii];
} else {
incoming_weights[ii] += incoming_moms[ii];
}
weight_in[i] = static_cast<T_weight>(incoming_weights[ii]);
// if necessary, write out an fp16 copy of the weights
if(N == 4)
model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
// also write out the new momentum
//if(momentum != 0.f)
mom_in[i] = static_cast<T_weight>(incoming_moms[ii]);
}
}
}
}
};
void multi_tensor_lars_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor grad_norms,
at::Tensor param_norms,
float lr,
float trust_coefficient,
float epsilon,
float weight_decay,
float momentum,
float dampening,
bool nesterov,
bool first_run,
bool wd_after_momentum,
float scale,
const bool is_skipped)
{
auto num_tensors = tensor_lists.size();
auto grad_type = tensor_lists[0][0].scalar_type();
auto weight_type = tensor_lists[1][0].scalar_type();
if(num_tensors == 4) {
for(int i = 0; i < tensor_lists[3].size(); i++) {
TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
"Additional output tensors should always be fp16.");
}
}
TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), "expected noop flag to be on the same device as tensors");
// We have 3 possibilities to handle here, in terms of
// grad_type, param_type, momentum_type, requires_fp16_copy
// 1. fp16, fp16, fp16, No
// 2. fp32, fp32, fp32, No
// 3. fp16, fp32, fp32, Yes
// 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
// 5. bfp16, bfp16, bfp16, No
// 6. bfp16, fp32, fp32, Yes
// It's easier to hardcode these possibilities than to use
// switches etc. to handle the cross-product of cases where
// we don't want the majority of them.
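  // Note: the bfloat16 cases (5 and 6) are reachable only when multi_tensor_lars
  // is called directly with bf16 tensor lists; the FusedLARS Python wrapper above
  // currently partitions parameters into fp16 and fp32 groups only.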
// Case 1. fp16, fp16, fp16, No
if(grad_type == at::ScalarType::Half &&
weight_type == at::ScalarType::Half &&
num_tensors == 3)
{
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
LARSFunctor<3, at::Half, at::Half>(),
grad_norms.DATA_PTR<float>(),
param_norms.DATA_PTR<float>(),
lr,
trust_coefficient,
epsilon,
weight_decay,
momentum,
dampening,
nesterov,
first_run,
wd_after_momentum,
scale,
is_skipped);
}
// Case 2. fp32, fp32, fp32, No
else if(grad_type == at::ScalarType::Float &&
weight_type == at::ScalarType::Float &&
num_tensors == 3)
{
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
LARSFunctor<3, float, float>(),
grad_norms.DATA_PTR<float>(),
param_norms.DATA_PTR<float>(),
lr,
trust_coefficient,
epsilon,
weight_decay,
momentum,
dampening,
nesterov,
first_run,
wd_after_momentum,
scale,
is_skipped);
}
// Case 3. fp16, fp32, fp32, Yes
else if(grad_type == at::ScalarType::Half &&
weight_type == at::ScalarType::Float &&
num_tensors == 4)
{
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
LARSFunctor<4, at::Half, float>(),
grad_norms.DATA_PTR<float>(),
param_norms.DATA_PTR<float>(),
lr,
trust_coefficient,
epsilon,
weight_decay,
momentum,
dampening,
nesterov,
first_run,
wd_after_momentum,
scale,
is_skipped);
}
// Case 4. fp32, fp32, fp32, Yes
else if(grad_type == at::ScalarType::Float &&
weight_type == at::ScalarType::Float &&
num_tensors == 4)
{
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
LARSFunctor<4, float, float>(),
grad_norms.DATA_PTR<float>(),
param_norms.DATA_PTR<float>(),
lr,
trust_coefficient,
epsilon,
weight_decay,
momentum,
dampening,
nesterov,
first_run,
wd_after_momentum,
scale,
is_skipped);
}
// Case 5. bfp16, bfp16, bfp16, No
else if(grad_type == at::ScalarType::BFloat16 &&
weight_type == at::ScalarType::BFloat16 &&
num_tensors == 3)
{
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
LARSFunctor<3, at::BFloat16, at::BFloat16>(),
grad_norms.DATA_PTR<float>(),
param_norms.DATA_PTR<float>(),
lr,
trust_coefficient,
epsilon,
weight_decay,
momentum,
dampening,
nesterov,
first_run,
wd_after_momentum,
scale,
is_skipped);
}
// Case 6. bfp16, fp32, fp32, Yes
else if(grad_type == at::ScalarType::BFloat16 &&
weight_type == at::ScalarType::Float &&
num_tensors == 4)
{
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
LARSFunctor<4, at::BFloat16, float>(),
grad_norms.DATA_PTR<float>(),
param_norms.DATA_PTR<float>(),
lr,
trust_coefficient,
epsilon,
weight_decay,
momentum,
dampening,
nesterov,
first_run,
wd_after_momentum,
scale,
is_skipped);
}
else
{
AT_ERROR("multi_tensor_lars only supports some combinations of gradient & weight types. Given: ",
"gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
}
AT_CUDA_CHECK(cudaGetLastError());
}
@@ -232,6 +232,7 @@ if "--cuda_ext" in sys.argv:
'csrc/multi_tensor_adam.cu',
'csrc/multi_tensor_adagrad.cu',
'csrc/multi_tensor_novograd.cu',
'csrc/multi_tensor_lars.cu',
'csrc/multi_tensor_lamb.cu',
'csrc/multi_tensor_lamb_mp.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
...
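A quick sanity check that a rebuilt extension exposes the new kernel (editor-added; the build command in the comment is the usual apex CUDA-extension install, stated as an assumption rather than a prescription):

# After rebuilding apex with the CUDA extensions (historically something like
#   pip install -v --no-cache-dir --global-option="--cuda_ext" ./
# from the apex checkout), the new op should be visible on the amp_C module:
import amp_C
assert hasattr(amp_C, "multi_tensor_lars"), "rebuild apex with --cuda_ext to pick up csrc/multi_tensor_lars.cu"
print(amp_C.multi_tensor_lars)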