Unverified Commit 3c8f5161 authored by Kevin Stephano, committed by GitHub

Add fused mixed precision lamb optimizer. (#1237)

* Add fused mixed precision lamb optimizer.

* Fix device usage in constructor.

* Fix sending param_group tensor state to device.

* Remove unneeded device set.
parent aa756cec
@@ -2,4 +2,5 @@ from .fused_sgd import FusedSGD
from .fused_adam import FusedAdam
from .fused_novograd import FusedNovoGrad
from .fused_lamb import FusedLAMB
from .fused_adagrad import FusedAdagrad
\ No newline at end of file
from .fused_adagrad import FusedAdagrad
from .fused_mixed_precision_lamb import FusedMixedPrecisionLamb
import torch
from copy import deepcopy
from itertools import chain
from collections import defaultdict, abc as container_abcs
from apex.multi_tensor_apply import multi_tensor_applier
class FusedMixedPrecisionLamb(torch.optim.Optimizer):
def __init__(self, params, lr=1e-3, step=0, bias_correction=True,
betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
amsgrad=False, adam_w_mode=True,
grad_averaging=True, max_grad_norm=1.0, use_nvlamb=False,
reduced_precision_dtype=None):
if amsgrad:
raise RuntimeError('FusedMixedPrecisionLamb does not support the AMSGrad variant.')
# The learning rate (lr) and optimizer step (step) should be located on the device
# in order to facilitate device-sync-free execution
defaults = dict(lr=torch.tensor(lr, dtype=torch.float32),
step=torch.tensor([step], dtype=torch.int),
bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
grad_averaging=grad_averaging,
max_grad_norm=max_grad_norm)
tensor_state = ['lr', 'step']
super(FusedMixedPrecisionLamb, self).__init__(params, defaults)
device = self.param_groups[0]['params'][0].device
for idx,group in enumerate(self.param_groups):
for item in tensor_state:
self.param_groups[idx][item] = group[item].to(device=device)
if multi_tensor_applier.available:
import amp_C
self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm_mp
# Skip buffer
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=device)
self.multi_tensor_lamb = amp_C.multi_tensor_lamb_mp
else:
raise RuntimeError('apex.optimizers.FusedMixedPrecisionLamb requires cuda extensions')
# Mixed Precision support
self.reduced_precision_dtype = reduced_precision_dtype
self.param_groups_full_precision = []
self._step_supports_amp_scaling = True
self.adam_w_mode = 1 if adam_w_mode else 0
self.use_nvlamb = use_nvlamb
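# Note: 'lr' and 'step' are stored as device tensors in the defaults and param
# groups, so a learning-rate schedule would update them in place, e.g.
# group['lr'].copy_(new_lr), rather than rebinding group['lr'] to a Python float
# (hypothetical usage, not enforced by this class).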
# This method is overridden from the parent class because there is no way to override
# the nested cast() function, which copies a saved piece of state to the device,
# without redundantly doing the copy.
def load_state_dict(self, state_dict):
r"""Loads the optimizer state.
Args:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`.
"""
# deepcopy, to be consistent with module API
state_dict = deepcopy(state_dict)
# Validate the state_dict
groups = self.param_groups
saved_groups = state_dict['param_groups']
if len(groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of "
"parameter groups")
param_lens = (len(g['params']) for g in groups)
saved_lens = (len(g['params']) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError("loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group")
# Update the state
id_map = {old_id: p for old_id, p in
zip(chain.from_iterable((g['params'] for g in saved_groups)),
chain.from_iterable((g['params'] for g in groups)))}
def cast(param, value):
r"""Make a deep copy of value, casting all tensors to device of param."""
if isinstance(value, torch.Tensor):
# The original version casted the saved value to the params dtype
# This doesn't work for mixed precision Lamb where the momentum and
# velocity are expected to be in full precision while the params are
# in reduced precision
value = value.to(value.device)
return value
elif isinstance(value, dict):
return {k: cast(param, v) for k, v in value.items()}
elif isinstance(value, container_abcs.Iterable):
return type(value)(cast(param, v) for v in value)
else:
return value
# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state = defaultdict(dict)
for k, v in state_dict['state'].items():
if k in id_map:
param = id_map[k]
state[param] = cast(param, v)
else:
state[k] = v
# Update parameter groups, setting their 'params' value
def update_group(group, new_group):
new_group['params'] = group['params']
return new_group
param_groups = [
update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({'state': state, 'param_groups': param_groups})
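# Net effect of the cast() override above: saved state tensors keep their stored
# dtype and device, so the FP32 exp_avg / exp_avg_sq buffers survive a save/load
# round trip even when the model params are held in a reduced-precision dtype.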
def _setup_full_precision_params(self):
for i, pg in enumerate(self.param_groups):
param_list = pg['params']
self.param_groups_full_precision.append({
'params': [
p.clone().detach().to(dtype=torch.float32)
if (self.reduced_precision_dtype is not None) and (p.dtype == self.reduced_precision_dtype)
else None
for p in param_list
],
})
# add_param_group() is overridden because default items can be tensors. The
# parent version does not clone a tensor default, so two param groups could
# accidentally share the same default tensor even though they are in separate
# groups and their values should be allowed to differ.
def add_param_group(self, param_group):
super().add_param_group(param_group)
for name, default in self.defaults.items():
if isinstance(default, torch.Tensor):
self.param_groups[len(self.param_groups) - 1][name] = default.clone()
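# Without the clone above, every param group created from tensor defaults would
# alias the same 'lr' and 'step' tensors, and an in-place update of one group's
# value would silently change the other groups as well.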
@torch.no_grad()
def step(self, closure=None, grad_scaler=None):
loss = None
if closure is not None:
loss = closure()
# The full precision params are set up in the first step of the optimizer
# instead of in the constructor because the full precision params would get
# out of sync with the model params if DDP syncs the model params across devices
# after the optimizer is constructed.
if len(self.param_groups_full_precision) == 0 :
self._setup_full_precision_params()
# create separate grad lists for params
grad_list = []
for gid,group in enumerate(self.param_groups):
for pid,p in enumerate(group['params']):
assert group['params'][0].dtype == p.dtype, \
"Error: Parameters within a group must all have the same dtype: {} != {}".format(
group['params'][0].dtype, p.dtype)
if p.grad is None:
continue
grad_list.append(p.grad)
# Overflow check of gradients
device = self.param_groups[0]["params"][0].device
found_inf = (
grad_scaler._check_inf_per_device(self)[device]
if grad_scaler is not None else torch.zeros((1,), device=device)
)
self._dummy_overflow_buf.copy_(found_inf)
# Get unscale scale factor
scale, inv_scale = None, None
if grad_scaler:
scale = grad_scaler._get_scale_async()
inv_scale = scale.double().reciprocal().float()
else:
scale = torch.ones((1,), device=device)
inv_scale = torch.ones((1,), device=device)
# grad_norm is of scaled gradients.
# So, multiply `max_grad_norm` by scale.
max_grad_norm = self.defaults['max_grad_norm'] * scale
grad_norm = multi_tensor_applier(
self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[grad_list],
False,
)[0]
# Run LAMB optimization math
for gid, (group, group_full) in enumerate(zip(self.param_groups, self.param_groups_full_precision)):
bias_correction = 1 if group['bias_correction'] else 0
beta1, beta2 = group['betas']
grad_averaging = 1 if group['grad_averaging'] else 0
# assume the same step across the group for now to simplify things
# a per-parameter step could easily be supported by making it a tensor, or by passing a list into the kernel
group['step'] += (self._dummy_overflow_buf != 1).to(torch.int)
state_lists = [ [], # (0) grads
[], # (1) params
[], # (2) momentum state
[], # (3) velocity state
]
if self.reduced_precision_dtype is not None:
state_lists.append([]) # (4) params reduced_dtype
for p, p_full in zip(group['params'], group_full['params']):
if p.grad is None:
continue
assert not p.grad.is_sparse
state = self.state[p]
# State initialization
if len(state) == 0:
dtype = p.dtype
if self.reduced_precision_dtype is not None and p.dtype == self.reduced_precision_dtype :
dtype = torch.float32
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data, dtype=dtype)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=dtype)
if self.reduced_precision_dtype is not None :
state_lists[0].append(p.grad.data)
state_lists[1].append(p_full.data)
state_lists[2].append(state['exp_avg'])
state_lists[3].append(state['exp_avg_sq'])
state_lists[4].append(p.data)
else :
state_lists[0].append(p.grad.data)
state_lists[1].append(p.data)
state_lists[2].append(state['exp_avg'])
state_lists[3].append(state['exp_avg_sq'])
multi_tensor_applier(
self.multi_tensor_lamb,
self._dummy_overflow_buf,
state_lists,
group['lr'],
beta1,
beta2,
group['eps'],
group['step'],
bias_correction,
group['weight_decay'],
grad_averaging,
self.adam_w_mode,
grad_norm,
max_grad_norm,
self.use_nvlamb,
found_inf,
inv_scale)
return loss
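A minimal usage sketch (hypothetical model, shapes, and hyperparameters, not part of this commit), assuming apex was built with --cuda_ext so that amp_C exposes multi_tensor_lamb_mp and multi_tensor_l2norm_mp:

import torch
from apex.optimizers import FusedMixedPrecisionLamb

# Reduced-precision model params; the optimizer keeps FP32 master copies internally.
model = torch.nn.Linear(1024, 1024).cuda().to(dtype=torch.float16)
optimizer = FusedMixedPrecisionLamb(model.parameters(), lr=5e-4, weight_decay=0.01,
                                    reduced_precision_dtype=torch.float16)

inp = torch.randn(8, 1024, device='cuda', dtype=torch.float16)
loss = model(inp).float().pow(2).mean()
loss.backward()
# With grad_scaler=None, step() uses scale = inv_scale = 1; a torch.cuda.amp.GradScaler
# can instead be passed explicitly via optimizer.step(grad_scaler=scaler).
optimizer.step()
optimizer.zero_grad()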
@@ -33,6 +33,12 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python);
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_mp_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python);
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_scale_cuda(
int chunk_size,
at::Tensor noop_flag,
@@ -119,6 +125,25 @@ void multi_tensor_lamb_cuda(
const float max_grad_norm,
at::optional<bool> use_nvlamb_python);
void multi_tensor_lamb_mp_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr,
const float beta1,
const float beta2,
const float epsilon,
at::Tensor step,
const int bias_correction,
const float weight_decay,
const int grad_averaging,
const int mode,
at::Tensor global_grad_norm,
at::Tensor max_grad_norm,
at::optional<bool> use_nvlamb_python,
at::Tensor found_inf,
at::Tensor inv_scale);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors");
@@ -128,6 +153,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"out = a*x + b*y for a list of contiguous tensors");
m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors");
m.def("multi_tensor_l2norm_mp", &multi_tensor_l2norm_mp_cuda,
"Computes L2 norm for a list of contiguous tensors");
m.def("multi_tensor_l2norm_scale", &multi_tensor_l2norm_scale_cuda,
"Computes L2 norm for a list of contiguous tensors and does scaling");
m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda,
@@ -142,4 +169,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Compute and apply gradient update to parameters for Adam optimizer");
m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
"Computes and apply update for LAMB optimizer");
m.def("multi_tensor_lamb_mp", &multi_tensor_lamb_mp_cuda,
"Computes and apply update for LAMB optimizer");
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
template<typename x_t>
struct L2NormFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<1>& tl,
float* output,
float* output_per_tensor,
bool per_tensor,
int max_chunks_per_tensor)
{
if (*noop_gmem) {
return;
}
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
x_t* x = (x_t*)tl.addresses[0][tensor_loc];
x += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
__shared__ float s_vals[512];
float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
x_t r_x[ILP];
for(int i = 0; i < ILP; i++)
{
vals[i] = 0.f;
r_x[i] = 0;
}
// to make things simple, we put aligned case in a different code path
if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
{
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
load_store(r_x, x, 0 , i_start);
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
float next = static_cast<float>(r_x[ii]);
vals[ii] += next*next;
}
}
}
else
{
for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
{
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
float next = static_cast<float>(x[i]);
vals[ii] += next*next;
}
}
}
}
float val = 0.f;
for(int i = 0; i < ILP; i++)
val += vals[i];
float final = reduce_block_into_lanes(s_vals, val);
if(threadIdx.x == 0)
{
if(!isfinite(final))
*noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
output[blockIdx.x] += final;
if(per_tensor)
output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
}
}
};
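// Each block above accumulates the sum of squares of its chunk into
// output[blockIdx.x] (and into output_per_tensor when per_tensor is set).
// The cleanup kernel below reduces those partial sums and takes the square root,
// and the noop_gmem flag lets both kernels return early once a non-finite value
// has been observed.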
__global__ void cleanup(
float* output,
float* output_per_tensor,
float* ret,
float* ret_per_tensor,
bool per_tensor,
int max_chunks_per_tensor,
volatile int* noop_gmem)
{
if (*noop_gmem) {
return;
}
__shared__ float vals[512];
if(blockIdx.x == 0)
{
float val = 0;
if(threadIdx.x < 320)
val = output[threadIdx.x];
float final = reduce_block_into_lanes(vals, val);
if(threadIdx.x == 0)
*ret = sqrt(final);
}
if(per_tensor)
{
float* output_this_tensor = output_per_tensor + blockIdx.x*max_chunks_per_tensor;
float val = 0;
for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
val += output_this_tensor[i];
float final = reduce_block_into_lanes(vals, val);
if(threadIdx.x == 0)
ret_per_tensor[blockIdx.x] = sqrt(final);
}
}
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_mp_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python)
{
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
if(per_tensor)
{
for(int t = 0; t < ntensors; t++)
{
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
if(max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
}
else
{
ret_per_tensor = at::empty({0}, float_options);
}
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_mp_cuda",
multi_tensor_apply<1>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
L2NormFunctor<scalar_t_0>(),
output.data_ptr<float>(),
per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
per_tensor,
max_chunks_per_tensor);)
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
// This involves one more small kernel launch, but it should be negligible end to end.
// It could be avoided by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now.
auto ret = at::empty({1}, output.options());
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
auto stream = at::cuda::getCurrentCUDAStream();
cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
output.data_ptr<float>(),
per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
ret.data_ptr<float>(),
per_tensor ? ret_per_tensor.data_ptr<float>() : nullptr,
per_tensor,
max_chunks_per_tensor, noop_flag.data_ptr<int>());
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
typedef enum{
MOMENT_MODE_0 =0, // L2 regularization mode
MOMENT_MODE_1 =1 // Decoupled weight decay mode
} adamMode_t;
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_mp_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python);
using MATH_T = float;
template<typename T, typename param_t>
struct LAMBStage1Functor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<4>& tl,
const float beta1,
const float beta2,
const float beta3,
const int* step_ptr,
const int bias_correction,
const float epsilon,
adamMode_t mode,
const float decay,
const float* global_grad_norm,
const float* max_global_grad_norm,
const float* found_inf,
const float* inv_scale)
{
if (*noop_gmem) {
return;
}
float beta1_correction = 1.0f;
float beta2_correction = 1.0f;
if (bias_correction == 1) {
int step = *step_ptr;
beta1_correction = 1 - std::pow(beta1, step);
beta2_correction = 1 - std::pow(beta2, step);
}
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float clipped_global_grad_norm = (*global_grad_norm) > (*max_global_grad_norm) ? (*global_grad_norm) / (*max_global_grad_norm) : 1.0f;
T* g = (T*)tl.addresses[0][tensor_loc];
g += chunk_idx*chunk_size;
param_t* p = (param_t*)tl.addresses[1][tensor_loc];
p += chunk_idx*chunk_size;
param_t* m = (param_t*)tl.addresses[2][tensor_loc];
m += chunk_idx*chunk_size;
param_t* v = (param_t*)tl.addresses[3][tensor_loc];
v += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
// to make things simple, we put aligned case in a different code path
if(n % ILP == 0 &&
chunk_size % ILP == 0 &&
is_aligned(g) &&
is_aligned(p) &&
is_aligned(m) &&
is_aligned(v))
{
T l_g[ILP];
param_t l_p[ILP];
param_t l_m[ILP];
param_t l_v[ILP];
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
load_store(l_g, g, 0, i_start);
if (decay != 0)
load_store(l_p, p, 0, i_start);
load_store(l_m, m, 0, i_start);
load_store(l_v, v, 0, i_start);
// unpack
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_g[ii] = l_g[ii] * (*inv_scale);
if (decay == 0) {
r_p[ii] = MATH_T(0);
}
else {
r_p[ii] = l_p[ii];
}
r_m[ii] = l_m[ii];
r_v[ii] = l_v[ii];
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
if (mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
// L2 on scaled grad
scaled_grad = scaled_grad + decay*r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = next_m_unbiased / denom;
}
else {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
l_p[ii] = r_p[ii];
// Difference from APEX's LAMB kernel. `g` and `p` can be different dtypes.
l_g[ii] = r_p[ii];
l_m[ii] = r_m[ii];
l_v[ii] = r_v[ii];
}
// store
load_store(g, l_g, i_start, 0);
load_store(m, l_m, i_start, 0);
load_store(v, l_v, i_start, 0);
}
}
else
{
// see note in multi_tensor_scale_kernel.cu
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
r_g[ii] = g[i] * (*inv_scale);
// special ?optimization? for lamb stage 1
if (decay == 0) {
r_p[ii] = MATH_T(0);
}
else {
r_p[ii] = p[i];
}
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
if (mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
// L2 on scaled grad
scaled_grad = scaled_grad + decay*r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = next_m_unbiased / denom;
}
else {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
g[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
}
}
}
}
}
};
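// Summary of the stage-1 math above (per element, computed in MATH_T = float,
// with g already multiplied by inv_scale):
//   g_hat  = g / max(||g||_2 / max_grad_norm, 1)
//   MOMENT_MODE_0 (L2 regularization):  g_hat += decay * p
//   m      = beta1 * m + beta3 * g_hat            (beta3 = 1 - beta1 when grad averaging)
//   v      = beta2 * v + (1 - beta2) * g_hat^2
//   update = (m / bc1) / (sqrt(v / bc2) + epsilon), with bc1 = 1 - beta1^step and
//            bc2 = 1 - beta2^step when bias correction is enabled, else 1
//   MOMENT_MODE_1 (decoupled weight decay):  update += decay * p
// The update overwrites the gradient buffer so that stage 2 can consume it
// (and compute its per-tensor norm) in place.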
// Step 2 reads in the 'update' value and the per-tensor param_norm and update_norm.
// It computes the new parameter value.
// N == 2: FP32 params, no master params
// N == 3: FP16 params, FP32 master params.
template<typename T, int N, typename param_t>
struct LAMBStage2Functor
{
static_assert((N == 2 && std::is_same<T, param_t>::value) || (N == 3 && std::is_same<param_t, float>::value), "");
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<N>& tl,
const float* per_tensor_param_norm,
const float* per_tensor_update_norm,
const float* learning_rate,
const float decay,
bool use_nvlamb)
{
if (*noop_gmem) {
return;
}
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
MATH_T ratio = *learning_rate;
// nvlamb: apply adaptive learning rate to all parameters
// otherwise, only apply to those with non-zero weight decay
if (use_nvlamb || (decay != 0.0))
{
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
ratio = (update_norm != 0.0f && param_norm != 0.0f) ? *learning_rate * (param_norm / update_norm) : *learning_rate;
}
T* update = (T*)tl.addresses[0][tensor_loc];
update += chunk_idx*chunk_size;
param_t* p = (param_t*)tl.addresses[1][tensor_loc];
p += chunk_idx*chunk_size;
T* out_p;
if (N == 3) {
out_p = (T*)tl.addresses[2][tensor_loc];
out_p += chunk_idx*chunk_size;
}
n -= chunk_idx*chunk_size;
// to make things simple, we put aligned case in a different code path
bool can_use_aligned_path = n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) && is_aligned(update);
if (N == 3) {
can_use_aligned_path = can_use_aligned_path && is_aligned(out_p);
}
if(can_use_aligned_path)
{
param_t r_p[ILP];
T r_update[ILP];
T r_out_p[ILP];
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
load_store(r_p, p, 0, i_start);
load_store(r_update, update, 0, i_start);
if (N == 3) {
load_store(r_out_p, out_p, 0, i_start);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * static_cast<MATH_T>(r_update[ii]));
if (N == 3) {
r_out_p[ii] = r_p[ii];
}
}
load_store(p, r_p, i_start, 0);
if (N == 3) {
load_store(out_p, r_out_p, i_start, 0);
}
}
}
else
{
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
MATH_T r_p[ILP];
MATH_T r_update[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
r_p[ii] = p[i];
r_update[ii] = update[i];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
p[i] = r_p[ii];
if (N == 3) {
out_p[i] = r_p[ii];
}
}
}
}
}
}
};
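// Summary of the stage-2 math above, with u = the stage-1 update for a tensor:
//   ratio = lr * ||p||_2 / ||u||_2   if (use_nvlamb or decay != 0) and both norms are nonzero
//   ratio = lr                       otherwise
//   p    -= ratio * u
// For N == 3 the FP32 master parameter is updated and the new value is also
// written back to the reduced-precision model copy (out_p).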
void multi_tensor_lamb_mp_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr,
const float beta1,
const float beta2,
const float epsilon,
at::Tensor step,
const int bias_correction,
const float weight_decay,
const int grad_averaging,
const int mode,
at::Tensor global_grad_norm,
at::Tensor max_grad_norm,
at::optional<bool> use_nvlamb_python,
at::Tensor found_inf,
at::Tensor inv_scale)
{
// n_tensors == 5: FP16 model params & FP32 master params
// n_tensors == 4: FP32 model params & NO FP32 master params
const auto n_tensors = tensor_lists.size();
assert(n_tensors == 4 || n_tensors == 5);
using namespace at;
bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
// note(mkozuki): move bias handling below to functor
// Handle bias correction mode
// float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
// if (bias_correction == 1) {
// bias_correction1 = 1 - std::pow(beta1, step);
// bias_correction2 = 1 - std::pow(beta2, step);
// }
// Handle grad averaging mode
float beta3 = 1.0f;
if (grad_averaging == 1) beta3 = 1 - beta1;
std::vector<std::vector<at::Tensor>> stage1_tensor_lists(tensor_lists.begin(), tensor_lists.begin() + 4);
std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);
// Compute per tensor param norm
auto param_norm_tuple = multi_tensor_l2norm_mp_cuda(chunk_size, noop_flag, param_list, true);
// We now modify grad in place to store the update before computing its norm.
// Generally this is not an issue since people modify grad in the step() method all the time.
// We could also grab a list of empty tensors to avoid this, but I'd like to save space/CPU code.
if (n_tensors == 4) {
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
stage1_tensor_lists,
LAMBStage1Functor<scalar_t_0, scalar_t_0>(),
beta1,
beta2,
beta3, // 1-beta1 or 1 depends on averaging mode
// bias_correction1,
// bias_correction2,
step.data_ptr<int>(),
bias_correction,
epsilon,
(adamMode_t) mode,
weight_decay,
global_grad_norm.data_ptr<float>(),
max_grad_norm.data_ptr<float>(),
found_inf.data_ptr<float>(),
inv_scale.data_ptr<float>()); )
} else {
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
stage1_tensor_lists,
LAMBStage1Functor<scalar_t_0, float>(),
beta1,
beta2,
beta3, // 1-beta1 or 1 depends on averaging mode
// bias_correction1,
// bias_correction2,
step.data_ptr<int>(),
bias_correction,
epsilon,
(adamMode_t) mode,
weight_decay,
global_grad_norm.data_ptr<float>(),
max_grad_norm.data_ptr<float>(),
found_inf.data_ptr<float>(),
inv_scale.data_ptr<float>()); )
}
// Compute update norms
auto update_norm_tuple = multi_tensor_l2norm_mp_cuda(chunk_size, noop_flag, grad_list, true);
std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);
if (n_tensors == 4) {
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
noop_flag,
grad_param_list,
LAMBStage2Functor<scalar_t_0, 2, scalar_t_0>(),
std::get<1>(param_norm_tuple).data_ptr<float>(),
std::get<1>(update_norm_tuple).data_ptr<float>(),
lr.data_ptr<float>(),
weight_decay,
use_nvlamb); )
} else {
grad_param_list.push_back(tensor_lists[4]);
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
noop_flag,
grad_param_list,
LAMBStage2Functor<scalar_t_0, 3, float>(),
std::get<1>(param_norm_tuple).data_ptr<float>(),
std::get<1>(update_norm_tuple).data_ptr<float>(),
lr.data_ptr<float>(),
weight_decay,
use_nvlamb); )
}
AT_CUDA_CHECK(cudaGetLastError());
}
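// The tensor_lists layout above mirrors the state_lists built in
// FusedMixedPrecisionLamb.step():
//   [0] grads, [1] params (the FP32 master copies when n_tensors == 5),
//   [2] exp_avg, [3] exp_avg_sq, [4] reduced-precision model params
//       (present only when n_tensors == 5).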
@@ -165,13 +165,15 @@ if "--cuda_ext" in sys.argv:
'csrc/multi_tensor_scale_kernel.cu',
'csrc/multi_tensor_axpby_kernel.cu',
'csrc/multi_tensor_l2norm_kernel.cu',
'csrc/multi_tensor_l2norm_kernel_mp.cu',
'csrc/multi_tensor_l2norm_scale_kernel.cu',
'csrc/multi_tensor_lamb_stage_1.cu',
'csrc/multi_tensor_lamb_stage_2.cu',
'csrc/multi_tensor_adam.cu',
'csrc/multi_tensor_adagrad.cu',
'csrc/multi_tensor_novograd.cu',
'csrc/multi_tensor_lamb.cu'],
'csrc/multi_tensor_lamb.cu',
'csrc/multi_tensor_lamb_mp.cu'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-lineinfo',
'-O3',
......
@@ -144,14 +144,14 @@ class RefLAMB(Optimizer):
return loss
class TestFusedLAMB(unittest.TestCase):
class TestLamb(unittest.TestCase):
def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
self.max_abs_diff = max_abs_diff
self.max_rel_diff = max_rel_diff
self.iters = iters
torch.cuda.manual_seed(9876)
def tearDown(self):
pass
@@ -162,8 +162,8 @@ class TestFusedLAMB(unittest.TestCase):
ref_param.append(torch.nn.Parameter(tensor.clone()))
tst_param.append(torch.nn.Parameter(tensor.clone()))
ref_optim = RefLAMB(ref_param, **lamb_option)
tst_optim = apex.optimizers.FusedLAMB(tst_param, use_nvlamb=True, **lamb_option)
ref_optim = self.ref_optim(ref_param, **lamb_option)
tst_optim = self.tst_optim(tst_param, use_nvlamb=True, **lamb_option)
return (ref_param, tst_param, ref_optim, tst_optim)
@@ -211,6 +211,13 @@ class TestFusedLAMB(unittest.TestCase):
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
class TestFusedLAMB(TestLamb):
def __init__(self, *args, **kwargs):
super(TestLamb, self).__init__(*args, **kwargs)
self.ref_optim = RefLAMB
self.tst_optim = apex.optimizers.FusedLAMB
def test_float(self):
self.gen_single_type_test(param_type=torch.float)
@@ -264,6 +271,65 @@ class TestFusedLAMB(unittest.TestCase):
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
class TestFusedMixedPrecisionLamb(TestLamb):
def __init__(self, *args, **kwargs):
super(TestLamb, self).__init__(*args, **kwargs)
self.ref_optim = RefLAMB
self.tst_optim = apex.optimizers.FusedMixedPrecisionLamb
def test_float(self):
self.gen_single_type_test(param_type=torch.float)
@unittest.skip("PyTorch optimizer is not numerically correct for fp16")
def test_half(self):
self.gen_single_type_test(param_type=torch.float16)
@unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
def test_multi_device(self):
devices = ("cuda:0", "cuda:1")
for current_dev, tensor_dev in product(devices, devices):
with torch.cuda.device(current_dev):
self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
def test_multi_params(self):
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
weight_decay = [0, 0.01]
for wd in weight_decay:
lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd}
tensors = []
for size in sizes:
tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
ref_param, tst_param, ref_optim, tst_optim = \
self.gen_param_optim(tensors, lamb_option)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step()
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
def test_lamb_option(self):
nelem = 1
tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
weight_decay = [0, 0.01]
for wd in weight_decay:
lamb_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, 'weight_decay':wd}
ref_param, tst_param, ref_optim, tst_optim = \
self.gen_param_optim([tensor], lamb_option)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step()
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
if __name__ == '__main__':
script_path = os.path.dirname(os.path.realpath(__file__))
......