Unverified commit 08e88b1b authored by athitten, committed by GitHub

Enable Distributed FusedLAMB (#57)

parent 51b402df
@@ -8,11 +8,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
   at::Tensor per_tensor_beta2,
   at::Tensor per_tensor_beta3,
   at::Tensor per_tensor_bias_correction,
-  const int step,
+  at::Tensor step,
   at::Tensor per_tensor_epsilon,
   const int mode,
   at::Tensor per_tensor_decay,
-  const float grad_scale);
+  at::Tensor global_scale,
+  at::Tensor global_grad_norm,
+  const float max_grad_norm);
 
 void multi_tensor_lamb_update_weights_cuda(
   int chunk_size,
@@ -20,8 +22,10 @@ void multi_tensor_lamb_update_weights_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::Tensor per_tensor_param_norm,
   at::Tensor per_tensor_update_norm,
-  const float learning_rate,
+  at::Tensor update_norm_offset,
+  at::Tensor learning_rate,
   at::Tensor per_tensor_decay,
+  at::Tensor global_grad_norm,
   bool use_nvlamb);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
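Note on the interface change above: host-side scalars (step, grad_scale, learning_rate) become device tensors (step, global_scale, global_grad_norm, learning_rate, plus update_norm_offset), so the distributed optimizer can read and update them on the GPU without a host-device synchronization each step. A minimal caller-side sketch in libtorch C++ (illustration only, not the repository's actual wrapper; it assumes a CUDA device is present):

// Hypothetical caller-side setup (not from this commit): the new signatures
// take small device tensors instead of host scalars.
#include <torch/torch.h>

int main() {
  auto opts_f = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  auto opts_i = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);

  torch::Tensor step = torch::zeros({1}, opts_i);                  // optimizer step count, kept on device
  torch::Tensor learning_rate = torch::full({1}, 1e-3, opts_f);    // lr as a device tensor
  torch::Tensor global_scale = torch::full({1}, 65536.0, opts_f);  // loss/grad scale
  torch::Tensor global_grad_norm = torch::zeros({1}, opts_f);      // filled by a reduction over grads

  // Each iteration can advance these values on the GPU stream, no read-back needed:
  step += 1;
  // multi_tensor_lamb_compute_update_term_cuda(..., step, ..., global_scale, global_grad_norm, max_grad_norm);
  return 0;
}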
@@ -116,28 +116,36 @@ struct DistOptLAMBStage1Functor
     const MATH_T* per_tensor_beta2,
     const MATH_T* per_tensor_beta3,
     const int* per_tensor_bias_correction,
-    const int step,
+    const int* step,
     const MATH_T* per_tensor_epsilon,
     adamMode_t mode,
     const MATH_T* per_tensor_decay,
-    const float grad_scale)
+    const MATH_T* global_scale,
+    const MATH_T* global_grad_norm,
+    const float max_grad_norm)
   {
     // I'd like this kernel to propagate infs/nans.
-    // if(*noop_gmem == 1)
-    // return;
+    if (*noop_gmem == 1)
+      return;
 
     int tensor_loc = tl.block_to_tensor[blockIdx.x];
     int tensor_num = tl.start_tensor_this_launch + tensor_loc;
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];
 
+    float combined_scale = *global_scale;
+    if (max_grad_norm > 0) {
+      combined_scale = max_grad_norm / (*global_grad_norm / *global_scale + 1e-6);
+      combined_scale = *global_scale / std::min((float) 1.0, combined_scale);
+    }
+
     MATH_T beta1 = per_tensor_beta1[tensor_num];
     MATH_T beta2 = per_tensor_beta2[tensor_num];
     MATH_T beta3 = 1 - beta1;
     MATH_T beta1_correction, beta2_correction;
     if (per_tensor_bias_correction[tensor_num] == 1) {
-      beta1_correction = 1 - pow(beta1, step);
-      beta2_correction = 1 - pow(beta2, step);
+      beta1_correction = 1 - pow(beta1, *step);
+      beta2_correction = 1 - pow(beta2, *step);
     } else {
       beta1_correction = (MATH_T) 1.0;
       beta2_correction = (MATH_T) 1.0;
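The new combined_scale block above folds global gradient-norm clipping into the gradient-unscale factor: gradients are divided by global_scale to undo loss scaling, and by an extra factor whenever the unscaled global norm exceeds max_grad_norm. A minimal host-side rendering of the same arithmetic (illustration only; the numeric inputs are made up):

// Dividing a raw gradient by combined_scale both undoes the loss scale and
// clips the update so the unscaled global norm does not exceed max_grad_norm.
#include <algorithm>
#include <cstdio>

float combined_scale(float global_scale, float global_grad_norm, float max_grad_norm) {
  float scale = global_scale;
  if (max_grad_norm > 0) {
    // ratio > 1 means the unscaled norm is already within the limit
    float ratio = max_grad_norm / (global_grad_norm / global_scale + 1e-6f);
    scale = global_scale / std::min(1.0f, ratio);
  }
  return scale;
}

int main() {
  // Norm within the limit: behaves like plain unscaling by global_scale.
  std::printf("%f\n", combined_scale(1024.f, 2048.f, 5.f));   // 1024 (unscaled norm 2 <= 5)
  // Norm above the limit: the divisor grows so the clipped norm equals max_grad_norm.
  std::printf("%f\n", combined_scale(1024.f, 10240.f, 5.f));  // 2048 (unscaled norm 10 -> clipped to 5)
  return 0;
}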
@@ -204,7 +212,7 @@ struct DistOptLAMBStage1Functor
       for(int ii = 0; ii < ILP; ii++)
       {
         if (mode == MOMENT_MODE_0) {
-          MATH_T scaled_grad = r_g[ii] / grad_scale;
+          MATH_T scaled_grad = r_g[ii] / combined_scale;
           // L2 on scaled grad
           scaled_grad = scaled_grad + decay*r_p[ii];
           r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -215,7 +223,7 @@ struct DistOptLAMBStage1Functor
           r_p[ii] = next_m_unbiased / denom;
         }
         else {
-          MATH_T scaled_grad = r_g[ii] / grad_scale;
+          MATH_T scaled_grad = r_g[ii] / combined_scale;
           r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
           r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
           MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
@@ -274,7 +282,7 @@ struct DistOptLAMBStage1Functor
       for(int ii = 0; ii < ILP; ii++)
       {
         if (mode == MOMENT_MODE_0) {
-          MATH_T scaled_grad = r_g[ii] / grad_scale;
+          MATH_T scaled_grad = r_g[ii] / combined_scale;
           // L2 on scaled grad
           scaled_grad = scaled_grad + decay*r_p[ii];
           r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -285,7 +293,7 @@ struct DistOptLAMBStage1Functor
           r_p[ii] = next_m_unbiased / denom;
         }
         else {
-          MATH_T scaled_grad = r_g[ii] / grad_scale;
+          MATH_T scaled_grad = r_g[ii] / combined_scale;
           r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
           r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
           MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
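The four hunks above make the same substitution in both branches of the Stage-1 update: grad_scale is replaced by the combined_scale computed earlier, so the Adam moment update consumes gradients that are unscaled and clipped in a single divide. A per-element sketch of the MOMENT_MODE_0 branch (illustration only, no ILP vectorization; names mirror the kernel, values are hypothetical):

// One element of the Stage-1 "update term" computation after this change.
#include <cmath>

float lamb_stage1_element(float g, float& m, float& v, float p,
                          float beta1, float beta2, float beta3,
                          float beta1_correction, float beta2_correction,
                          float epsilon, float decay, float combined_scale) {
  float scaled_grad = g / combined_scale;        // unscale + clip in one divide
  scaled_grad = scaled_grad + decay * p;         // L2 applied to the scaled grad
  m = m * beta1 + beta3 * scaled_grad;           // first moment
  v = v * beta2 + (1.0f - beta2) * scaled_grad * scaled_grad;  // second moment
  float next_m_unbiased = m / beta1_correction;
  float next_v_unbiased = v / beta2_correction;
  float denom = std::sqrt(next_v_unbiased) + epsilon;
  return next_m_unbiased / denom;                // consumed by Stage 2
}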
@@ -321,13 +329,15 @@ struct DistOptLAMBStage2Functor
     TensorListMetadata<3>& tl,
     const MATH_T* per_tensor_param_norm,
     const MATH_T* per_tensor_update_norm,
-    const MATH_T learning_rate,
+    const long* update_norm_offset,
+    const MATH_T* learning_rate,
     const MATH_T* per_tensor_decay,
+    const MATH_T* global_grad_norm,
     bool use_nvlamb)
   {
     // I'd like this kernel to propagate infs/nans.
-    // if(*noop_gmem == 1)
-    // return;
+    if (*noop_gmem == 1)
+      return;
 
     int tensor_loc = tl.block_to_tensor[blockIdx.x];
     int tensor_num = tl.start_tensor_this_launch + tensor_loc;
@@ -336,14 +346,14 @@ struct DistOptLAMBStage2Functor
     MATH_T decay = per_tensor_decay[tensor_num];
 
-    MATH_T ratio = learning_rate;
+    MATH_T ratio = *learning_rate;
     // nvlamb: apply adaptive learning rate to all parameters
     // otherwise, only apply to those with non-zero weight decay
     if (use_nvlamb || (decay != (MATH_T) 0.0))
     {
       MATH_T param_norm = per_tensor_param_norm[tensor_num];
-      MATH_T update_norm = per_tensor_update_norm[tensor_num];
-      ratio = (update_norm != 0.0 && param_norm != 0.0) ? learning_rate * (param_norm / update_norm) : learning_rate;
+      MATH_T update_norm = per_tensor_update_norm[update_norm_offset[tensor_num]];
+      ratio = (update_norm != 0.0 && param_norm != 0.0) ? (*learning_rate) * (param_norm / update_norm) : (*learning_rate);
     }
 
     MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc];
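Stage 2 scales each tensor's update term by the LAMB trust ratio. After this change the learning rate is read through a device pointer, and the per-tensor update norm is indexed through update_norm_offset, which lets the norm buffer be laid out for a sharded parameter set where a tensor's local slot and its norm slot can differ. A scalar sketch of the ratio logic (illustration only; it mirrors the kernel above):

// Trust ratio: lr * ||param|| / ||update|| when both norms are non-zero.
#include <cmath>

float lamb_trust_ratio(float learning_rate, float param_norm, float update_norm,
                       float decay, bool use_nvlamb) {
  float ratio = learning_rate;
  // nvlamb applies the adaptive rate to every tensor; otherwise only to
  // tensors with non-zero weight decay.
  if (use_nvlamb || decay != 0.0f) {
    if (update_norm != 0.0f && param_norm != 0.0f)
      ratio = learning_rate * (param_norm / update_norm);
  }
  return ratio;  // each parameter then steps by p -= ratio * update
}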
@@ -374,7 +384,7 @@ struct DistOptLAMBStage2Functor
 #pragma unroll
       for(int ii = 0; ii < ILP; ii++)
       {
-        r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * r_update[ii]);
+        r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * r_update[ii]);
         convert(r_p[ii], r_p_copy[ii]);
       }
       load_store(p, r_p, i_start, 0);
@@ -427,11 +437,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
   at::Tensor per_tensor_beta2,
   at::Tensor per_tensor_beta3,
   at::Tensor per_tensor_bias_correction,
-  const int step,
+  at::Tensor step,
   at::Tensor per_tensor_epsilon,
   const int mode,
   at::Tensor per_tensor_decay,
-  const float grad_scale)
+  at::Tensor global_scale,
+  at::Tensor global_grad_norm,
+  const float max_grad_norm)
 {
   using namespace at;
@@ -448,11 +460,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
       per_tensor_beta2.DATA_PTR<scalar_t_2>(),
       per_tensor_beta3.DATA_PTR<scalar_t_2>(),
       per_tensor_bias_correction.DATA_PTR<int>(),
-      step,
+      step.DATA_PTR<int>(),
       per_tensor_epsilon.DATA_PTR<scalar_t_2>(),
       (adamMode_t) mode,
      per_tensor_decay.DATA_PTR<scalar_t_2>(),
-      grad_scale); )))
+      global_scale.DATA_PTR<scalar_t_2>(),
+      global_grad_norm.DATA_PTR<scalar_t_2>(),
+      max_grad_norm); )))
 
   AT_CUDA_CHECK(cudaGetLastError());
 }
@@ -463,8 +477,10 @@ void multi_tensor_lamb_update_weights_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::Tensor per_tensor_param_norm,
   at::Tensor per_tensor_update_norm,
-  const float learning_rate,
+  at::Tensor update_norm_offset,
+  at::Tensor learning_rate,
   at::Tensor per_tensor_decay,
+  at::Tensor global_grad_norm,
   bool use_nvlamb)
 {
   using namespace at;
@@ -480,8 +496,10 @@ void multi_tensor_lamb_update_weights_cuda(
       DistOptLAMBStage2Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),
       per_tensor_param_norm.DATA_PTR<scalar_t_2>(),
       per_tensor_update_norm.DATA_PTR<scalar_t_2>(),
-      (scalar_t_2) learning_rate,
+      update_norm_offset.DATA_PTR<long>(),
+      learning_rate.DATA_PTR<scalar_t_2>(),
       per_tensor_decay.DATA_PTR<scalar_t_2>(),
+      global_grad_norm.DATA_PTR<scalar_t_2>(),
       use_nvlamb); )))
 
   AT_CUDA_CHECK(cudaGetLastError());
...