Commit f3868524 authored by Abhishree

Enable Distributed FusedLAMB
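The distributed FusedLAMB kernels now take their per-step scalars (step, learning_rate, the global loss scale, and the global gradient norm) as single-element tensors rather than host scalars, gain max_grad_norm-based gradient clipping via a combined scale factor, look up each tensor's update norm through a new update_norm_offset argument, and honor the noop flag with an early exit in both kernel stages.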

parent abb6e5ba
@@ -8,11 +8,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
   at::Tensor per_tensor_beta2,
   at::Tensor per_tensor_beta3,
   at::Tensor per_tensor_bias_correction,
-  const int step,
+  at::Tensor step,
   at::Tensor per_tensor_epsilon,
   const int mode,
   at::Tensor per_tensor_decay,
-  const float grad_scale);
+  at::Tensor global_scale,
+  at::Tensor global_grad_norm,
+  const float max_grad_norm);
 
 void multi_tensor_lamb_update_weights_cuda(
   int chunk_size,
@@ -20,8 +22,10 @@ void multi_tensor_lamb_update_weights_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::Tensor per_tensor_param_norm,
   at::Tensor per_tensor_update_norm,
-  const float learning_rate,
+  at::Tensor update_norm_offset,
+  at::Tensor learning_rate,
   at::Tensor per_tensor_decay,
+  at::Tensor global_grad_norm,
   bool use_nvlamb);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
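The module body is cut off by the hunk. For orientation, a minimal sketch of how the two declarations above might be registered; the def names and docstrings are assumptions, not shown in this diff, and the sketch relies on the declarations earlier in the file:

#include <torch/extension.h>

// Hypothetical registration of the updated entry points; the new signatures
// pass step, learning_rate, global_scale and global_grad_norm as at::Tensor
// rather than host scalars.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_lamb_compute_update_term",
        &multi_tensor_lamb_compute_update_term_cuda,
        "Computes LAMB update term (stage 1)");
  m.def("multi_tensor_lamb_update_weights",
        &multi_tensor_lamb_update_weights_cuda,
        "Applies LAMB update to weights (stage 2)");
}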
@@ -116,28 +116,36 @@ struct DistOptLAMBStage1Functor
     const MATH_T* per_tensor_beta2,
     const MATH_T* per_tensor_beta3,
     const int* per_tensor_bias_correction,
-    const int step,
+    const int* step,
     const MATH_T* per_tensor_epsilon,
     adamMode_t mode,
     const MATH_T* per_tensor_decay,
-    const float grad_scale)
+    const MATH_T* global_scale,
+    const MATH_T* global_grad_norm,
+    const float max_grad_norm)
   {
-    // I'd like this kernel to propagate infs/nans.
-    // if(*noop_gmem == 1)
-    // return;
+    if (*noop_gmem == 1)
+      return;
 
     int tensor_loc = tl.block_to_tensor[blockIdx.x];
     int tensor_num = tl.start_tensor_this_launch + tensor_loc;
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];
 
+    float combined_scale = *global_scale;
+    if (max_grad_norm > 0) {
+      combined_scale = max_grad_norm / (*global_grad_norm / *global_scale + 1e-6);
+      combined_scale = *global_scale / std::min((float) 1.0, combined_scale);
+    }
+
     MATH_T beta1 = per_tensor_beta1[tensor_num];
     MATH_T beta2 = per_tensor_beta2[tensor_num];
     MATH_T beta3 = 1 - beta1;
 
     MATH_T beta1_correction, beta2_correction;
     if (per_tensor_bias_correction[tensor_num] == 1) {
-      beta1_correction = 1 - pow(beta1, step);
-      beta2_correction = 1 - pow(beta2, step);
+      beta1_correction = 1 - pow(beta1, *step);
+      beta2_correction = 1 - pow(beta2, *step);
     } else {
       beta1_correction = (MATH_T) 1.0;
       beta2_correction = (MATH_T) 1.0;
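The new combined_scale block folds max_grad_norm clipping into the divisor applied to every gradient element: when the unscaled global norm (*global_grad_norm / *global_scale) exceeds max_grad_norm, the divisor grows proportionally, so the unscaled gradients come out with global norm at most max_grad_norm. A host-side restatement of that arithmetic, as a sketch:

#include <algorithm>

// Reference arithmetic for combined_scale: clip < 1 means the unscaled
// gradient norm exceeds max_grad_norm, so the divisor is enlarged by the
// inverse clip factor; otherwise gradients are only unscaled.
float compute_combined_scale(float global_scale, float global_grad_norm,
                             float max_grad_norm) {
  float combined_scale = global_scale;
  if (max_grad_norm > 0.f) {
    float clip = max_grad_norm / (global_grad_norm / global_scale + 1e-6f);
    combined_scale = global_scale / std::min(1.0f, clip);
  }
  return combined_scale;
}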
@@ -204,7 +212,7 @@ struct DistOptLAMBStage1Functor
         for(int ii = 0; ii < ILP; ii++)
         {
           if (mode == MOMENT_MODE_0) {
-            MATH_T scaled_grad = r_g[ii] / grad_scale;
+            MATH_T scaled_grad = r_g[ii] / combined_scale;
             // L2 on scaled grad
             scaled_grad = scaled_grad + decay*r_p[ii];
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -215,7 +223,7 @@ struct DistOptLAMBStage1Functor
             r_p[ii] = next_m_unbiased / denom;
           }
           else {
-            MATH_T scaled_grad = r_g[ii] / grad_scale;
+            MATH_T scaled_grad = r_g[ii] / combined_scale;
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
             r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
             MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
@@ -274,7 +282,7 @@ struct DistOptLAMBStage1Functor
         for(int ii = 0; ii < ILP; ii++)
         {
           if (mode == MOMENT_MODE_0) {
-            MATH_T scaled_grad = r_g[ii] / grad_scale;
+            MATH_T scaled_grad = r_g[ii] / combined_scale;
             // L2 on scaled grad
             scaled_grad = scaled_grad + decay*r_p[ii];
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -285,7 +293,7 @@ struct DistOptLAMBStage1Functor
             r_p[ii] = next_m_unbiased / denom;
           }
           else {
-            MATH_T scaled_grad = r_g[ii] / grad_scale;
+            MATH_T scaled_grad = r_g[ii] / combined_scale;
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
             r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
             MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
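All four unrolled loops (the aligned fast path and the remainder path, in both moment modes) now divide by combined_scale instead of grad_scale. Restating the MOMENT_MODE_0 branch as scalar code for clarity; the denom lines elided from the hunk are filled in from the visible fragments, so treat them as an assumption:

#include <cmath>

// Scalar sketch of the stage-1 per-element update (MOMENT_MODE_0, L2
// regularization); returns the update term that stage 2 later applies.
float lamb_stage1_elem(float g, float p, float& m, float& v,
                       float beta1, float beta2, float beta3,
                       float beta1_correction, float beta2_correction,
                       float epsilon, float decay, float combined_scale) {
  float scaled_grad = g / combined_scale;   // unscale + clip in one divide
  scaled_grad = scaled_grad + decay * p;    // L2 on scaled grad
  m = m * beta1 + beta3 * scaled_grad;
  v = v * beta2 + (1.f - beta2) * scaled_grad * scaled_grad;
  float next_m_unbiased = m / beta1_correction;
  float next_v_unbiased = v / beta2_correction;
  float denom = std::sqrt(next_v_unbiased) + epsilon;
  return next_m_unbiased / denom;
}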
@@ -321,13 +329,15 @@ struct DistOptLAMBStage2Functor
     TensorListMetadata<3>& tl,
     const MATH_T* per_tensor_param_norm,
     const MATH_T* per_tensor_update_norm,
-    const MATH_T learning_rate,
+    const long* update_norm_offset,
+    const MATH_T* learning_rate,
     const MATH_T* per_tensor_decay,
+    const MATH_T* global_grad_norm,
     bool use_nvlamb)
   {
-    // I'd like this kernel to propagate infs/nans.
-    // if(*noop_gmem == 1)
-    // return;
+    if (*noop_gmem == 1)
+      return;
 
     int tensor_loc = tl.block_to_tensor[blockIdx.x];
     int tensor_num = tl.start_tensor_this_launch + tensor_loc;
@@ -336,14 +346,14 @@ struct DistOptLAMBStage2Functor
 
     MATH_T decay = per_tensor_decay[tensor_num];
 
-    MATH_T ratio = learning_rate;
+    MATH_T ratio = *learning_rate;
     // nvlamb: apply adaptive learning rate to all parameters
     // otherwise, only apply to those with non-zero weight decay
     if (use_nvlamb || (decay != (MATH_T) 0.0))
     {
       MATH_T param_norm = per_tensor_param_norm[tensor_num];
-      MATH_T update_norm = per_tensor_update_norm[tensor_num];
-      ratio = (update_norm != 0.0 && param_norm != 0.0) ? learning_rate * (param_norm / update_norm) : learning_rate;
+      MATH_T update_norm = per_tensor_update_norm[update_norm_offset[tensor_num]];
+      ratio = (update_norm != 0.0 && param_norm != 0.0) ? (*learning_rate) * (param_norm / update_norm) : (*learning_rate);
     }
 
     MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc];
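Stage 2 now dereferences learning_rate from a device pointer and fetches each tensor's update norm through the new update_norm_offset indirection. A host-side restatement of the trust-ratio logic above, as a sketch rather than the kernel code:

// Sketch of the LAMB trust ratio: parameters with weight decay (or all
// parameters under nvlamb) scale the base learning rate by the ratio of
// the parameter norm to the update norm; zero norms fall back to the raw
// learning rate.
float compute_ratio(const float* learning_rate, const float* param_norms,
                    const float* update_norms, const long* update_norm_offset,
                    int tensor_num, float decay, bool use_nvlamb) {
  float ratio = *learning_rate;
  if (use_nvlamb || decay != 0.0f) {
    float param_norm = param_norms[tensor_num];
    float update_norm = update_norms[update_norm_offset[tensor_num]];
    ratio = (update_norm != 0.0f && param_norm != 0.0f)
                ? (*learning_rate) * (param_norm / update_norm)
                : *learning_rate;
  }
  return ratio;
}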
@@ -374,7 +384,7 @@ struct DistOptLAMBStage2Functor
 #pragma unroll
         for(int ii = 0; ii < ILP; ii++)
         {
-          r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * r_update[ii]);
+          r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * r_update[ii]);
           convert(r_p[ii], r_p_copy[ii]);
         }
         load_store(p, r_p, i_start, 0);
@@ -427,11 +437,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
   at::Tensor per_tensor_beta2,
   at::Tensor per_tensor_beta3,
   at::Tensor per_tensor_bias_correction,
-  const int step,
+  at::Tensor step,
   at::Tensor per_tensor_epsilon,
   const int mode,
   at::Tensor per_tensor_decay,
-  const float grad_scale)
+  at::Tensor global_scale,
+  at::Tensor global_grad_norm,
+  const float max_grad_norm)
 {
   using namespace at;
@@ -448,11 +460,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
       per_tensor_beta2.DATA_PTR<scalar_t_2>(),
       per_tensor_beta3.DATA_PTR<scalar_t_2>(),
       per_tensor_bias_correction.DATA_PTR<int>(),
-      step,
+      step.DATA_PTR<int>(),
       per_tensor_epsilon.DATA_PTR<scalar_t_2>(),
       (adamMode_t) mode,
      per_tensor_decay.DATA_PTR<scalar_t_2>(),
-      grad_scale); )))
+      global_scale.DATA_PTR<scalar_t_2>(),
+      global_grad_norm.DATA_PTR<scalar_t_2>(),
+      max_grad_norm); )))
 
   AT_CUDA_CHECK(cudaGetLastError());
 }
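The stage-1 launcher now unpacks step, global_scale and global_grad_norm with DATA_PTR inside the dispatch macro. A hypothetical caller-side sketch (names and values illustrative); the presumed motivation, not stated in the commit, is that device-resident scalars can be produced and updated by other kernels without a host round-trip:

#include <torch/torch.h>

// Hypothetical setup of the new tensor-valued scalar arguments: step is a
// single-element int32 CUDA tensor (read via step.DATA_PTR<int>()), while
// the loss scale and global gradient norm live in float32 CUDA tensors.
void make_scalar_args(at::Tensor& step, at::Tensor& global_scale,
                      at::Tensor& global_grad_norm) {
  auto i32 = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
  auto f32 = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  step = torch::ones({1}, i32);
  global_scale = torch::full({1}, 65536.0, f32);   // illustrative loss scale
  global_grad_norm = torch::zeros({1}, f32);       // filled by a norm kernel
}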
@@ -463,8 +477,10 @@ void multi_tensor_lamb_update_weights_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::Tensor per_tensor_param_norm,
   at::Tensor per_tensor_update_norm,
-  const float learning_rate,
+  at::Tensor update_norm_offset,
+  at::Tensor learning_rate,
   at::Tensor per_tensor_decay,
+  at::Tensor global_grad_norm,
   bool use_nvlamb)
 {
   using namespace at;
@@ -480,8 +496,10 @@ void multi_tensor_lamb_update_weights_cuda(
       DistOptLAMBStage2Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),
       per_tensor_param_norm.DATA_PTR<scalar_t_2>(),
       per_tensor_update_norm.DATA_PTR<scalar_t_2>(),
-      (scalar_t_2) learning_rate,
+      update_norm_offset.DATA_PTR<long>(),
+      learning_rate.DATA_PTR<scalar_t_2>(),
       per_tensor_decay.DATA_PTR<scalar_t_2>(),
+      global_grad_norm.DATA_PTR<scalar_t_2>(),
       use_nvlamb); )))
 
   AT_CUDA_CHECK(cudaGetLastError());
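update_norm_offset is read as int64 in the stage-2 launcher (DATA_PTR<long>), one index per tensor. A hypothetical helper showing the trivial identity mapping, which reproduces the old tensor_num indexing; a distributed caller would presumably remap entries into its shared norm buffer instead:

#include <torch/torch.h>

// Identity offsets: update_norm_offset[i] == i makes
// per_tensor_update_norm[update_norm_offset[tensor_num]] behave exactly
// like the previous per_tensor_update_norm[tensor_num] lookup.
at::Tensor identity_offsets(int64_t num_tensors) {
  return torch::arange(num_tensors,
      torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA));
}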