Unverified commit 85b56d01, authored by Jeff Daily, committed by GitHub

Merge pull request #43 from ROCmSoftwarePlatform/IFU-2021-01-18

IFU-2021-01-18
parents d061bf20 13c8d152
@@ -150,12 +150,12 @@ CUDA and C++ extensions via
```
$ git clone https://github.com/NVIDIA/apex
$ cd apex
-$ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
+$ pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```
Apex also supports a Python-only build (required with Pytorch 0.4) via
```
-$ pip install -v --no-cache-dir ./
+$ pip install -v --disable-pip-version-check --no-cache-dir ./
```
A Python-only build omits:
- Fused kernels required to use `apex.optimizers.FusedAdam`.
...
@@ -113,7 +113,7 @@ torch::Tensor bwd_cuda(
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
-dispatch_masked_scale_softmax_backward<half, half, float,false>(
+dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
@@ -121,7 +121,7 @@ torch::Tensor bwd_cuda(
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
-attn_batches*q_seq_len);
+attn_batches*q_seq_len, stream);
//backward pass is completely in-place
return output_grads;
}
...
@@ -115,7 +115,7 @@ torch::Tensor bwd_cuda(
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
if (padding_mask == nullptr) {
-dispatch_masked_scale_softmax_backward<half, half, float,false>(
+dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
@@ -123,9 +123,9 @@ torch::Tensor bwd_cuda(
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
-attn_batches*q_seq_len);
+attn_batches*q_seq_len, stream);
} else{
-dispatch_masked_scale_softmax_backward_masked_out<half, half, float,false>(
+dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
@@ -135,7 +135,7 @@ torch::Tensor bwd_cuda(
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
-heads);
+heads, stream);
}
//backward pass is completely in-place
...
#pragma once
//Philox CUDA.
class Philox {
public:
__device__ inline Philox(unsigned long long seed,
unsigned long long subsequence,
unsigned long long offset) {
key.x = (unsigned int)seed;
key.y = (unsigned int)(seed >> 32);
counter = make_uint4(0, 0, 0, 0);
counter.z = (unsigned int)(subsequence);
counter.w = (unsigned int)(subsequence >> 32);
STATE = 0;
incr_n(offset / 4);
}
__device__ inline uint4 operator()() {
if(STATE == 0) {
uint4 counter_ = counter;
uint2 key_ = key;
//7-round philox
for(int i = 0; i < 6; i++) {
counter_ = single_round(counter_, key_);
key_.x += (kPhilox10A); key_.y += (kPhilox10B);
}
output = single_round(counter_, key_);
incr();
}
//return a float4 directly
//unsigned long ret;
//switch(STATE) {
// case 0: ret = output.x; break;
// case 1: ret = output.y; break;
// case 2: ret = output.z; break;
// case 3: ret = output.w; break;
//}
//STATE = (STATE + 1) % 4;
return output;
}
private:
uint4 counter;
uint4 output;
uint2 key;
unsigned int STATE;
__device__ inline void incr_n(unsigned long long n) {
unsigned int nlo = (unsigned int)(n);
unsigned int nhi = (unsigned int)(n >> 32);
counter.x += nlo;
if (counter.x < nlo)
nhi++;
counter.y += nhi;
if (nhi <= counter.y)
return;
if (++counter.z)
return;
++counter.w;
}
__device__ inline void incr() {
if (++counter.x)
return;
if (++counter.y)
return;
if (++counter.z)
return;
++counter.w;
}
__device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
unsigned int *result_high) {
*result_high = __umulhi(a, b);
return a*b;
}
__device__ inline uint4 single_round(uint4 ctr, uint2 key) {
unsigned int hi0;
unsigned int hi1;
unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
return ret;
}
static const unsigned long kPhilox10A = 0x9E3779B9;
static const unsigned long kPhilox10B = 0xBB67AE85;
static const unsigned long kPhiloxSA = 0xD2511F53;
static const unsigned long kPhiloxSB = 0xCD9E8D57;
};
// Inverse of 2^32.
#define M_RAN_INVM32 2.3283064e-10f
__device__ __inline__ float4 uniform4(uint4 x) {
return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,x.w * M_RAN_INVM32);
}
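For orientation, the header above is a counter-based generator: each thread constructs a `Philox` instance from a (seed, subsequence, offset) triple, and every call to `operator()` yields four 32-bit values that `uniform4` maps into [0, 1). Below is a minimal sketch of that usage pattern, assuming `philox.h` is included; the kernel name is made up for illustration and is not part of this commit.
```
// Illustrative only: draw four uniforms per thread and turn them into a
// keep/drop mask, mirroring how the fused softmax-dropout kernels later in
// this diff consume Philox. `example_dropout_mask` is a hypothetical name.
__global__ void example_dropout_mask(uint8_t *mask, int n,
                                     unsigned long long seed,
                                     unsigned long long offset,
                                     float keep_prob) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    Philox ph(seed, tid, offset);        // one independent subsequence per thread
    float4 r = uniform4(ph());           // four uniforms in [0, 1)
    int i = 4 * tid;
    if (i + 3 < n) {
        mask[i + 0] = r.x <= keep_prob;  // 1 = keep, 0 = drop
        mask[i + 1] = r.y <= keep_prob;
        mask[i + 2] = r.z <= keep_prob;
        mask[i + 3] = r.w <= keep_prob;
    }
}
```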
@@ -24,7 +24,9 @@ std::vector<torch::Tensor> bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
-torch::Tensor const& softmax_results,
+// torch::Tensor const& softmax_results,
+torch::Tensor const& bmm1_results,
+torch::Tensor const& pad_mask,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
@@ -60,6 +62,7 @@ std::vector<torch::Tensor> fwd(
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
+AT_ASSERTM(use_mask , "no mask is not supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
@@ -85,7 +88,8 @@ std::vector<torch::Tensor> bwd(
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
-torch::Tensor const& softmax_results,
+torch::Tensor const& bmm1_results,
+torch::Tensor const& pad_mask,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
@@ -97,7 +101,6 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
-AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
@@ -107,7 +110,6 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
-AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
@@ -119,7 +121,8 @@ std::vector<torch::Tensor> bwd(
output_grads,
matmul2_results,
dropout_results,
-softmax_results,
+bmm1_results,
+pad_mask,
input_lin_results,
inputs,
input_weights,
...
@@ -63,7 +63,7 @@ std::vector<torch::Tensor> fwd_cuda(
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
-torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
+torch::Tensor bmm1_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
@@ -75,7 +75,8 @@ std::vector<torch::Tensor> fwd_cuda(
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
-void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
+void* bmm1_results_ptr = static_cast<void*>(bmm1_results.data_ptr());
+void* dropout_results_ptr = static_cast<void*>(dropout_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
@@ -119,23 +120,29 @@ std::vector<torch::Tensor> fwd_cuda(
lead_dim,
batch_stride,
beta_zero,
-static_cast<half*>(softmax_results_ptr),
+static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// Padded Softmax
bool softmax_success = false;
-if (pad_mask == nullptr) {
-softmax_success = dispatch_softmax<half, half, float>(
-reinterpret_cast<half*>(softmax_results_ptr),
-reinterpret_cast<const half*>(softmax_results_ptr),
-k_seq_len,
-k_seq_len,
-attn_batches*q_seq_len);
+if (is_training) {
+softmax_success = dispatch_additive_masked_softmax_dropout<half, half, float>(
+reinterpret_cast<half*>(dropout_results_ptr),
+(is_training) ? reinterpret_cast<uint8_t*>(dropout_mask.data_ptr<uint8_t>()) : nullptr,
+reinterpret_cast<const half*>(bmm1_results_ptr),
+pad_mask,
+attn_batches*q_seq_len*q_seq_len,
+k_seq_len,
+k_seq_len,
+attn_batches*q_seq_len,
+attn_batches*q_seq_len/sequences,
+1.0f-dropout_prob,
+stream);
} else {
softmax_success = dispatch_additive_masked_softmax<half, half, float>(
-reinterpret_cast<half*>(softmax_results_ptr),
-reinterpret_cast<const half*>(softmax_results_ptr),
+reinterpret_cast<half*>(dropout_results_ptr),//this is actually softmax results, but making it consistent for the next function
+reinterpret_cast<const half*>(bmm1_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
@@ -143,14 +150,6 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches*q_seq_len/sequences);
}
-if (is_training) {
-//use at:: function so that C++ version generates the same random mask as python version
-auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);
-dropout_results = std::get<0>(dropout_tuple);
-dropout_mask = std::get<1>(dropout_tuple);
-}
// Matmul2
gemm_switch_fp32accum( state,
a_layout_n,
@@ -162,7 +161,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
-(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
+static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta_zero,
@@ -199,7 +198,7 @@ std::vector<torch::Tensor> fwd_cuda(
return {
input_lin_results,
-softmax_results,
+bmm1_results,
dropout_results,
dropout_mask,
matmul2_results,
@@ -212,7 +211,8 @@ std::vector<torch::Tensor> bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
-torch::Tensor const& softmax_results,
+torch::Tensor const& bmm1_results,
+torch::Tensor const& pad_mask,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
@@ -350,15 +350,18 @@ std::vector<torch::Tensor> bwd_cuda(
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
-dispatch_masked_scale_softmax_backward<half, half, float,false>(
+dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(
static_cast<half*>(matmul2_grads.data_ptr()),
-static_cast<half*>(matmul2_grads.data_ptr()),
+static_cast<half* const>(matmul2_grads.data_ptr()),
-reinterpret_cast<half const*>(softmax_results.data_ptr()),
+reinterpret_cast<half const*>(bmm1_results.data_ptr()),
+reinterpret_cast<half const*>(pad_mask.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
-attn_batches*q_seq_len);
+attn_batches*q_seq_len/sequences,
+attn_batches*q_seq_len,
+stream);
// Matmul1 Dgrad1
gemm_switch_fp32accum( state,
...
@@ -361,7 +361,7 @@ std::vector<torch::Tensor> bwd_cuda(
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
-dispatch_masked_scale_softmax_backward<half, half, float,false>(
+dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
@@ -369,7 +369,7 @@ std::vector<torch::Tensor> bwd_cuda(
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
-attn_batches*q_seq_len);
+attn_batches*q_seq_len, stream);
// Matmul1 Dgrad1
gemm_switch_fp32accum( state,
...
#pragma once
+//#include <ATen/ATen.h>
+#ifdef OLD_GENERATOR
+#include <ATen/CUDAGenerator.h>
+#else
+#include <ATen/CUDAGeneratorImpl.h>
+#endif
+//#include <ATen/cuda/CUDAContext.h>
+#include <curand_kernel.h>
+#include "philox.h"
+//#include <THC/THCGeneral.h>
#include <assert.h>
#include <cfloat>
@@ -17,6 +29,14 @@ namespace {
template <>
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) { *dst = *src; }
+template <>
+__device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); }
+template <>
+__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
+template <>
+__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_mask(Datatype *dst, Datatype value, const uint8_t *src);
@@ -24,9 +44,20 @@ namespace {
__device__ __inline__ void apply_mask<__half, 1>(__half *dst, __half value, const uint8_t *src) {
if (*src == 1) { *dst = value; }
}
+template <typename Datatype, int ELEMENTS_PER_LDG>
+__device__ __inline__ void apply_additive_mask(Datatype *dst, const Datatype *additive_mask);
+template <>
+__device__ __inline__ void apply_additive_mask<__half, 1>(__half *dst, const __half *additive_mask) {
+*dst += *additive_mask;
+}
+template <>
+__device__ __inline__ void apply_additive_mask<__half, 4>(__half *dst, const __half *additive_mask) {
+*dst += *additive_mask;
+*(dst+1) += *(additive_mask+1);
+*(dst+2) += *(additive_mask+2);
+*(dst+3) += *(additive_mask+3);}
} // namespace anonymous
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Warp Softmax forward
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -156,6 +187,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batc
}
}
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
@@ -237,10 +269,428 @@ bool dispatch_softmax(output_t *dst, const input_t *src, int softmax_elements, i
// compute launch size
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
return true;
}
return false;
}
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE, int ELEMENTS_PER_LDG_STG>
__global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride, std::pair<uint64_t,uint64_t> seeds, float p)
{
assert(ELEMENTS_PER_LDG_STG==4);
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;
acc_t pinv = acc_t(1)/p;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
//vectorize if element_count is multiple of 4, else don't vectorize
input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
src += thread_offset;
dst += thread_offset;
dropout_mask += thread_offset;
// load data from global memory
for (int i = 0;i < WARP_BATCH;++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;
const half* curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
//masking_value is a large negative value
elements_input[i][it + element] = -10000;
}
if (element_index < batch_element_count) {
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);
apply_additive_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], curr_mask + itr_jmp); //(__half)-std::numeric_limits<float>::infinity()
}
}
}
// convert input_t to acc_t
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;++it) {
elements[i][it] = elements_input[i][it];
}
}
constexpr uint32_t FULL_MASK = 0xffffffff;
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1;it < WARP_ITERATIONS;++it) {
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
Philox ph(seeds.first, tid, seeds.second);
uint8_t rands[WARP_BATCH][WARP_ITERATIONS];
float4 rand_num;
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
rand_num = uniform4(ph());
rands[i][it] = (rand_num.x <= p) > 0.5;
rands[i][it+1] = (rand_num.y <= p) > 0.5;
rands[i][it+2] = (rand_num.z <= p) > 0.5;
rands[i][it+3] = (rand_num.w <= p) > 0.5;
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(dropout_mask + i * element_count + it * WARP_SIZE, &rands[i][it]);
}
}
}
// store result
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
out[element] = rands[i][it+element] * (pinv * (elements[i][it + element] / sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
}
else {
break;
}
}
}
}
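Note the dropout convention used by the kernel above (and by the scalar variant that follows): `p` is the probability of keeping an element, the host passes `1.0f - dropout_prob`, and surviving elements are scaled by `pinv = 1/p` so that the expected output equals the plain softmax. A minimal host-side sketch of the same inverted-dropout rule follows; the helper name is illustrative and not part of the commit.
```
#include <cstdint>

// Inverted dropout applied to one already-computed softmax value:
// keep with probability keep_prob and rescale survivors by 1/keep_prob,
// so E[apply_dropout(x)] == x. Mirrors out = mask * (pinv * softmax) above.
inline float apply_dropout(float softmax_val, float uniform_sample,
                           float keep_prob, uint8_t *mask_out) {
    uint8_t keep = uniform_sample <= keep_prob;  // 1 = keep, 0 = drop
    *mask_out = keep;                            // saved for the backward pass
    return keep ? softmax_val / keep_prob : 0.0f;
}
```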
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE, int ELEMENTS_PER_LDG_STG>
__global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride, std::pair<uint64_t,uint64_t> seeds, float p)
{
assert(ELEMENTS_PER_LDG_STG==1);
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;
acc_t pinv = acc_t(1)/p;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
//vectorize if element_count is multiple of 4, else don't vectorize
input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
int thread_offset = first_batch * stride + local_idx;
src += thread_offset;
dst += thread_offset;
dropout_mask += thread_offset;
// load data from global memory
for (int i = 0;i < WARP_BATCH;++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + local_idx;
const half* curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += 1) {
int element_index = local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0;element < 1;++element) {
//masking_value is a large negative value
elements_input[i][it + element] = -10000;
}
if (element_index < batch_element_count) {
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
copy_vector<input_t, 1>(&elements_input[i][it], src + itr_idx);
apply_additive_mask<input_t, 1>(&elements_input[i][it], curr_mask + itr_jmp);
}
}
}
// convert input_t to acc_t
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;++it) {
elements[i][it] = elements_input[i][it];
}
}
constexpr uint32_t FULL_MASK = 0xffffffff;
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1;it < WARP_ITERATIONS;++it) {
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
curandStatePhilox4_32_10_t state;
curand_init(
seeds.first,
tid,
seeds.second,
&state);
// store result
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += 1) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
output_t out[1];
acc_t softmax_out[1];
uint8_t dropout_mask_temp[1];
//generate a vector of random numbers here
float rand = curand_uniform(&state);
float *rand_ptr = (float*)(&rand);
#pragma unroll
for (int element = 0;element < 1;++element) {
softmax_out[element] = (elements[i][it + element] / sum[i]);
rand_ptr[element] = rand_ptr[element] <= p;
out[element] = rand_ptr[element] * pinv * softmax_out[element];
dropout_mask_temp[element] = rand_ptr[element] > 0.5; // just to distinguish 0.0f and 1.0f
}
copy_vector<output_t, 1>(dst + i * element_count + it * WARP_SIZE, out);
copy_vector<uint8_t, 1>(dropout_mask + i * element_count + it * WARP_SIZE, dropout_mask_temp);
}
else {
break;
}
}
}
}
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t, typename acc_t>
using additive_masked_softmax_dropout_forward_func = void(*)(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride, std::pair<uint64_t,uint64_t> seeds, float p);
template <typename input_t, typename output_t, typename acc_t>
bool warp_additive_masked_softmax_dropout_kernel(int element_count, int log2_elements, int &warp_size, int &batches_per_warp, additive_masked_softmax_dropout_forward_func<input_t, output_t, acc_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
bool flag_vec4 = (element_count % 4 == 0);
switch (log2_elements) {
case 0: // 1
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,1,1>;
break;
case 1: // 2
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,2,1>;
break;
case 2: // 4
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,4,1>;
break;
case 3: // 8
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,8,1>;
break;
case 4: // 16
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,16,1>;
break;
case 5: // 32
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,32,1>;
break;
case 6: // 64
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,2,32,1>;
break;
case 7: // 128
if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 2,4,32,4>;
else kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;
//kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;
break;
case 8: // 256
if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,8,32,4>;
else kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,8,32,1>;
break;
case 9: // 512
if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,16,32,4>;
else kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,16,32,1>;
break;
case 10: // 1024
if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,32,32,4>;
else kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,32,32,1>;
break;
case 11: // 2048
if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,64,32,4>;
else kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,64,32,1>;
break;
default:
return false;
}
return true;
}
template<typename input_t, typename output_t, typename acc_t>
bool dispatch_additive_masked_softmax_dropout(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int totalElements, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride, float p, cudaStream_t streamid)// p is the probability to keep, not drop
{
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 2048) {
// compute function index. there's a function for each power of two size up to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements) ++log2_elements;
additive_masked_softmax_dropout_forward_func<input_t, output_t, acc_t> kernel;
int warp_size, batches_per_warp;
if (!warp_additive_masked_softmax_dropout_kernel<input_t, output_t, acc_t>(softmax_elements, log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
// compute warps per block.
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
auto gen = at::cuda::detail::getDefaultCUDAGenerator();
int64_t counter_offset = (totalElements/(blocks*threads_per_block)+1);
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
#ifdef OLD_GENERATOR
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
#else
std::lock_guard<std::mutex> lock(gen.mutex());
rng_engine_inputs = at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(counter_offset);
#endif
}
// compute launch size
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, streamid>>>(dst, dropout_mask, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride, rng_engine_inputs, p);
return true;
}
return false;
@@ -1214,8 +1664,6 @@ void dispatch_masked_scale_softmax_backward_masked_out_stream(output_t *grad_inp
}
}
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>
__global__ void masked_scale_softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output, const uint8_t *mask, acc_t scale, int batch_size, int stride, int element_count)
{
@@ -1296,81 +1744,285 @@ __global__ void masked_scale_softmax_warp_backward(output_t *gradInput, const in
}
}
-template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
-void dispatch_masked_scale_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int batch_count)
-{
-TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
-if (softmax_elements == 0) {
-return;
-} else {
-int log2_elements = log2_ceil_native(softmax_elements);
-const int next_power_of_two = 1 << log2_elements;
-// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
-int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
-// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
-int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE=32, int ELEMENTS_PER_LDG_STG, bool is_log_softmax>
__global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput, const input_t *grad, const input_t *softmax_input, const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size, int stride, int pad_batch_stride, int element_count)
{
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x % WARP_SIZE;
//vectorize if a row length is multiple of 4
int flag_vec4 = element_count & 3 == 0;
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] ;
input_t elements_input[WARP_BATCH][WARP_ITERATIONS] ;
// the first element to process by the current thread
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
softmax_input += thread_offset;
gradInput += thread_offset;
mask += thread_offset;
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
// but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
// the nested loops.
// This should have no impact on performance because the loops are unrolled anyway.
// load data from global memory
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;
const input_t* curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
//masking_value is a large negative value
elements_input[i][it + element] = -10000;
grad_reg[i][it+element] = acc_t(0);
}
if (element_index < batch_element_count) {
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], softmax_input + itr_idx);
apply_additive_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], curr_mask + itr_jmp); //(__half)-std::numeric_limits<float>::infinity()
uint8_t mask_temp[ELEMENTS_PER_LDG_STG];
input_t grad_temp[ELEMENTS_PER_LDG_STG];
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(&mask_temp[0], mask + itr_idx);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_temp[0], grad + itr_idx);
#pragma unroll
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
grad_reg[i][it+element] = ((acc_t)mask_temp[element] * (acc_t)grad_temp[element] * (acc_t)scale );
}
}
}
}
// load data from global memory
// convert input_t to acc_t
// TODO : remove this, input is already acc_t type in register
acc_t elements[WARP_BATCH][WARP_ITERATIONS] ;
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;++it) {
elements[i][it] = elements_input[i][it];
}
}
constexpr uint32_t FULL_MASK = 0xffffffff;
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1;it < WARP_ITERATIONS;++it) {
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
//elements[i][it] = expf(elements[i][it] - max_value[i]);
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
// store result
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it ++) {
elements[i][it] = elements[i][it] / sum[i];
grad_reg[i][it] = grad_reg[i][it] * elements[i][it];
}
}
acc_t grad_sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
grad_sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
grad_sum[i] += grad_reg[i][it];
}
}
warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(grad_sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t grad_input_reg[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element=0; element<ELEMENTS_PER_LDG_STG; element++) {
if (is_log_softmax) {
grad_input_reg[element] = (grad_reg[i][it+element] - std::exp(elements[i][it+element]) * grad_sum[i]);
} else {
grad_input_reg[element] = (grad_reg[i][it+element] - elements[i][it+element] * grad_sum[i]);
}
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, grad_input_reg);
}
}
}
}
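To summarize what the kernel above computes per row: the softmax is recomputed from the stored matmul1 output plus the additive pad mask (instead of reading a saved softmax tensor), the incoming gradient is gated by the dropout mask and rescaled by `scale = 1/(1 - dropout_prob)`, and the usual softmax backward formula is then applied. The scalar CPU reference below is only a hedged illustration of the non-log-softmax path; the helper name is made up and not part of the commit.
```
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Reference for one row of length n:
//   y  = softmax(bmm1 + pad_mask)        (recomputed, not stored)
//   g  = dropout_mask * grad * scale     (scale = 1/(1 - dropout_prob))
//   dx = y * (g - sum_j g_j * y_j)
std::vector<float> softmax_backward_recompute_row(
        const std::vector<float> &grad, const std::vector<float> &bmm1,
        const std::vector<float> &pad_mask,
        const std::vector<uint8_t> &dropout_mask, float scale) {
    const size_t n = grad.size();
    std::vector<float> y(n), dx(n);
    float max_val = -INFINITY, sum = 0.0f, dot = 0.0f;
    for (size_t j = 0; j < n; ++j)
        max_val = std::max(max_val, bmm1[j] + pad_mask[j]);
    for (size_t j = 0; j < n; ++j) {
        y[j] = std::exp(bmm1[j] + pad_mask[j] - max_val);
        sum += y[j];
    }
    for (size_t j = 0; j < n; ++j) {
        y[j] /= sum;
        dot += dropout_mask[j] * grad[j] * scale * y[j];
    }
    for (size_t j = 0; j < n; ++j)
        dx[j] = y[j] * (dropout_mask[j] * grad[j] * scale - dot);
    return dx;
}
```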
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
using masked_scale_softmax_warp_backward_recompute_func = void(*)(output_t *gradInput, const input_t *grad, const input_t *softmax_input, const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size, int stride, int pad_batch_stride, int element_count);
template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
bool masked_scale_softmax_warp_backward_recompute_kernel(int element_count, int log2_elements, int &warp_size, int &batches_per_warp, masked_scale_softmax_warp_backward_recompute_func<input_t, output_t, acc_t, is_log_softmax> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
bool flag_vec4 = (element_count % 4 == 0);
switch (log2_elements) {
case 0: // 1
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,1,1, is_log_softmax>;
break;
case 1: // 2
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,2,1, is_log_softmax>;
break;
case 2: // 4
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,4,1, is_log_softmax>;
break;
case 3: // 8
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,8,1, is_log_softmax>;
break;
case 4: // 16
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,16,1, is_log_softmax>;
break;
case 5: // 32
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,32,1, is_log_softmax>;
break;
case 6: // 64
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,2,32,1, is_log_softmax>;
break;
case 7: // 128
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,4,32,1, is_log_softmax>;
break;
case 8: // 256
if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,8,32,4, is_log_softmax>;
else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,8,32,1, is_log_softmax>;
break;
case 9: // 512
if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,16,32,4, is_log_softmax>;
else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,16,32,1, is_log_softmax>;
break;
case 10: // 1024
if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,32,32,4, is_log_softmax>;
else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,32,32,1, is_log_softmax>;
break;
case 11: // 2048
if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,64,32,4, is_log_softmax>;
else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,64,32,1, is_log_softmax>;
break;
default:
return false;
}
return true;
}
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
bool dispatch_masked_scale_softmax_backward_recompute(output_t *grad_input, const input_t *grad, const input_t *softmax_input, const input_t *pad_mask, const uint8_t *mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int pad_batch_stride, int batch_count, cudaStream_t streamid)
{
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 2048) {
// compute function index. there's a function for each power of two size up to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements) ++log2_elements;
masked_scale_softmax_warp_backward_recompute_func<input_t, output_t, acc_t, is_log_softmax> kernel;
int warp_size, batches_per_warp;
if (!masked_scale_softmax_warp_backward_recompute_kernel<input_t, output_t, acc_t, is_log_softmax>(softmax_elements, log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
// compute warps per block.
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
// compute launch size
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, streamid>>>(grad_input, grad, softmax_input, pad_mask, mask, scale, batch_count, softmax_elements_stride, pad_batch_stride, softmax_elements);
return true;
}
return false;
}
-// Launch code would be more elegant if C++ supported FOR CONSTEXPR
-switch (log2_elements) {
-case 0: // 1
-masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 0, is_log_softmax>
-<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
-break;
-case 1: // 2
-masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 1, is_log_softmax>
-<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
-break;
-case 2: // 4
-masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 2, is_log_softmax>
-<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
-break;
-case 3: // 8
-masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 3, is_log_softmax>
-<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
-break;
-case 4: // 16
-masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 4, is_log_softmax>
-<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
-break;
-case 5: // 32
-masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 5, is_log_softmax>
-<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
-break;
-case 6: // 64
-masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 6, is_log_softmax>
-<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
-break;
-case 7: // 128
-masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 7, is_log_softmax>
-<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
-break;
-case 8: // 256
-masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 8, is_log_softmax>
-<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
-break;
-case 9: // 512
-masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 9, is_log_softmax>
-<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
-break;
-case 10: // 1024
-masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 10, is_log_softmax>
-<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
-break;
-default:
-break;
-}
-}
-}
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_masked_scale_softmax_backward_stream(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int batch_count, cudaStream_t streamid)
{
...
#include <torch/extension.h>
void multi_tensor_fused_adam_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_beta1,
at::Tensor per_tensor_beta2,
at::Tensor per_tensor_bias_correction,
at::Tensor per_tensor_eps,
at::Tensor per_tensor_weight_decay,
float lr,
float grad_scale,
int step,
int mode);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_fused_adam", &multi_tensor_fused_adam_cuda,
"Multi tensor Adam optimized CUDA implementation.");
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <THC/THCGeneral.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include <cmath>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
typedef enum{
ADAM_MODE_0 =0, // eps under square root
ADAM_MODE_1 =1 // eps outside square root
} adamMode_t;
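The two modes only differ in where epsilon enters the denominator of the Adam update; the functor below selects between them per tensor. A scalar illustration follows (the helper name is made up and not part of the commit).
```
#include <cmath>

// ADAM_MODE_0: eps under the square root; ADAM_MODE_1: eps outside it.
// denom is then used as: update = m_hat / denom + weight_decay * p.
inline float adam_denom(float v_hat, float eps, int mode) {
    return (mode == 0) ? sqrtf(v_hat + eps) : (sqrtf(v_hat) + eps);
}
```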
template <int DEPTH, typename T, typename GRAD_T>
struct DistAdamFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<DEPTH>& tl,
const float* per_tensor_beta1,
const float* per_tensor_beta2,
const int* per_tensor_bias_correction,
const float* per_tensor_eps,
const float* per_tensor_weight_decay,
const float lr,
const float grad_scale,
const int step,
adamMode_t mode)
{
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float b1 = per_tensor_beta1[tensor_num];
float b2 = per_tensor_beta2[tensor_num];
float eps = per_tensor_eps[tensor_num];
float decay = per_tensor_weight_decay[tensor_num];
float beta1_correction = 1.0f, beta2_correction = 1.0f;
if (per_tensor_bias_correction[tensor_num] == 1) {
beta1_correction = 1 - std::pow(b1, step);
beta2_correction = 1 - std::pow(b2, step);
}
T* p = (T *)tl.addresses[0][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T *)tl.addresses[1][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T *)tl.addresses[2][tensor_loc];
v += chunk_idx*chunk_size;
GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
g += chunk_idx*chunk_size;
GRAD_T* p_copy = NULL;
if (DEPTH == 5) {
p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
p_copy += chunk_idx*chunk_size;
}
n -= chunk_idx*chunk_size;
T incoming_p[ILP];
T incoming_m[ILP];
T incoming_v[ILP];
T incoming_g[ILP];
// To keep things simple, the fully aligned case takes a separate code path
if (n % ILP == 0 &&
chunk_size % ILP == 0 &&
is_aligned(p) &&
is_aligned(m) &&
is_aligned(v) &&
is_aligned(g) &&
is_aligned(p_copy)) {
for (int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) {
// load
GRAD_T tmp_g[ILP];
load_store(incoming_p, p, 0, i_start);
load_store(incoming_m, m, 0, i_start);
load_store(incoming_v, v, 0, i_start);
load_store(tmp_g, g, 0, i_start);
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
incoming_g[ii] = static_cast<T>(tmp_g[ii]);
T scaled_grad = incoming_g[ii]/grad_scale;
incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
T next_m_unbiased = incoming_m[ii] / beta1_correction;
T next_v_unbiased = incoming_v[ii] / beta2_correction;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(next_v_unbiased + eps);
else // Mode 1
denom = sqrtf(next_v_unbiased) + eps;
float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]);
incoming_p[ii] = incoming_p[ii] - (lr * update);
if (DEPTH == 5) tmp_g[ii] = static_cast<GRAD_T>(incoming_p[ii]);
}
load_store(p, incoming_p, i_start, 0);
load_store(m, incoming_m, i_start, 0);
load_store(v, incoming_v, i_start, 0);
if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0);
}
} else {
for (int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
incoming_p[ii] = 0;
incoming_m[ii] = 0;
incoming_v[ii] = 0;
incoming_g[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x;
if (i < n && i < chunk_size) {
incoming_p[ii] = p[i];
incoming_m[ii] = m[i];
incoming_v[ii] = v[i];
incoming_g[ii] = static_cast<T>(g[i]);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int j = i_start + threadIdx.x + ii*blockDim.x;
if (j < n && j < chunk_size) {
T scaled_grad = incoming_g[ii]/grad_scale;
m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
T next_m_unbiased = m[j] / beta1_correction;
T next_v_unbiased = v[j] / beta2_correction;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(next_v_unbiased + eps);
else // Mode 1
denom = sqrtf(next_v_unbiased) + eps;
float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]);
p[j] = incoming_p[ii] - (lr * update);
if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j];
}
}
}
}
}
};
void multi_tensor_fused_adam_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
at::Tensor per_tensor_beta1,
at::Tensor per_tensor_beta2,
at::Tensor per_tensor_bias_correction,
at::Tensor per_tensor_eps,
at::Tensor per_tensor_weight_decay,
float lr,
float grad_scale,
int step,
int mode)
{
using namespace at;
size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");
if (tl_sz == 5) {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
DistAdamFunctor<5, accscalar_t, scalar_t_0>(),
per_tensor_beta1.DATA_PTR<float>(),
per_tensor_beta2.DATA_PTR<float>(),
per_tensor_bias_correction.DATA_PTR<int>(),
per_tensor_eps.DATA_PTR<float>(),
per_tensor_weight_decay.DATA_PTR<float>(),
lr,
grad_scale,
step,
(adamMode_t) mode);
);
} else {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
DistAdamFunctor<4, accscalar_t, scalar_t_0>(),
per_tensor_beta1.DATA_PTR<float>(),
per_tensor_beta2.DATA_PTR<float>(),
per_tensor_bias_correction.DATA_PTR<int>(),
per_tensor_eps.DATA_PTR<float>(),
per_tensor_weight_decay.DATA_PTR<float>(),
lr,
grad_scale,
step,
(adamMode_t) mode);
);
}
THCudaCheck(cudaGetLastError());
}
...@@ -12,8 +12,7 @@ void multi_tensor_lamb_compute_update_term_cuda( ...@@ -12,8 +12,7 @@ void multi_tensor_lamb_compute_update_term_cuda(
at::Tensor per_tensor_epsilon, at::Tensor per_tensor_epsilon,
const int mode, const int mode,
at::Tensor per_tensor_decay, at::Tensor per_tensor_decay,
const float global_grad_norm, const float grad_scale);
const float max_global_grad_norm);
void multi_tensor_lamb_update_weights_cuda( void multi_tensor_lamb_update_weights_cuda(
int chunk_size, int chunk_size,
......
...@@ -120,8 +120,7 @@ struct DistOptLAMBStage1Functor ...@@ -120,8 +120,7 @@ struct DistOptLAMBStage1Functor
const MATH_T* per_tensor_epsilon, const MATH_T* per_tensor_epsilon,
adamMode_t mode, adamMode_t mode,
const MATH_T* per_tensor_decay, const MATH_T* per_tensor_decay,
const MATH_T global_grad_norm, const float grad_scale)
const MATH_T max_global_grad_norm)
{ {
// I'd like this kernel to propagate infs/nans. // I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1) // if(*noop_gmem == 1)
...@@ -132,15 +131,13 @@ struct DistOptLAMBStage1Functor ...@@ -132,15 +131,13 @@ struct DistOptLAMBStage1Functor
int chunk_idx = tl.block_to_chunk[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc]; int n = tl.sizes[tensor_loc];
MATH_T clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : (MATH_T) 1.0;
MATH_T beta1 = per_tensor_beta1[tensor_num]; MATH_T beta1 = per_tensor_beta1[tensor_num];
MATH_T beta2 = per_tensor_beta1[tensor_num]; MATH_T beta2 = per_tensor_beta2[tensor_num];
MATH_T beta3 = per_tensor_beta1[tensor_num]; MATH_T beta3 = 1 - beta1;
MATH_T beta1_correction, beta2_correction; MATH_T beta1_correction, beta2_correction;
if (per_tensor_bias_correction[tensor_num] == 1) { if (per_tensor_bias_correction[tensor_num] == 1) {
beta1_correction = 1 - pow(beta1, (MATH_T) step); beta1_correction = 1 - pow(beta1, step);
beta2_correction = 1 - pow(beta2, (MATH_T) step); beta2_correction = 1 - pow(beta2, step);
} else { } else {
beta1_correction = (MATH_T) 1.0; beta1_correction = (MATH_T) 1.0;
beta2_correction = (MATH_T) 1.0; beta2_correction = (MATH_T) 1.0;
...@@ -207,7 +204,7 @@ struct DistOptLAMBStage1Functor ...@@ -207,7 +204,7 @@ struct DistOptLAMBStage1Functor
for(int ii = 0; ii < ILP; ii++) for(int ii = 0; ii < ILP; ii++)
{ {
if (mode == MOMENT_MODE_0) { if (mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; MATH_T scaled_grad = r_g[ii] / grad_scale;
// L2 on scaled grad // L2 on scaled grad
scaled_grad = scaled_grad + decay*r_p[ii]; scaled_grad = scaled_grad + decay*r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
...@@ -218,7 +215,7 @@ struct DistOptLAMBStage1Functor ...@@ -218,7 +215,7 @@ struct DistOptLAMBStage1Functor
r_p[ii] = next_m_unbiased / denom; r_p[ii] = next_m_unbiased / denom;
} }
else { else {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; MATH_T scaled_grad = r_g[ii] / grad_scale;
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction; MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
...@@ -277,7 +274,7 @@ struct DistOptLAMBStage1Functor ...@@ -277,7 +274,7 @@ struct DistOptLAMBStage1Functor
for(int ii = 0; ii < ILP; ii++) for(int ii = 0; ii < ILP; ii++)
{ {
if (mode == MOMENT_MODE_0) { if (mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; MATH_T scaled_grad = r_g[ii] / grad_scale;
// L2 on scaled grad // L2 on scaled grad
scaled_grad = scaled_grad + decay*r_p[ii]; scaled_grad = scaled_grad + decay*r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
...@@ -288,7 +285,7 @@ struct DistOptLAMBStage1Functor ...@@ -288,7 +285,7 @@ struct DistOptLAMBStage1Functor
r_p[ii] = next_m_unbiased / denom; r_p[ii] = next_m_unbiased / denom;
} }
else { else {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; MATH_T scaled_grad = r_g[ii] / grad_scale;
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction; MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
...@@ -346,7 +343,7 @@ struct DistOptLAMBStage2Functor ...@@ -346,7 +343,7 @@ struct DistOptLAMBStage2Functor
{ {
MATH_T param_norm = per_tensor_param_norm[tensor_num]; MATH_T param_norm = per_tensor_param_norm[tensor_num];
MATH_T update_norm = per_tensor_update_norm[tensor_num]; MATH_T update_norm = per_tensor_update_norm[tensor_num];
ratio = (update_norm != (MATH_T) 0.0 && param_norm != (MATH_T) 0.0) ? learning_rate * (param_norm / update_norm) : learning_rate; ratio = (update_norm != 0.0 && param_norm != 0.0) ? learning_rate * (param_norm / update_norm) : learning_rate;
} }
MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc]; MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc];
...@@ -434,8 +431,7 @@ void multi_tensor_lamb_compute_update_term_cuda( ...@@ -434,8 +431,7 @@ void multi_tensor_lamb_compute_update_term_cuda(
at::Tensor per_tensor_epsilon, at::Tensor per_tensor_epsilon,
const int mode, const int mode,
at::Tensor per_tensor_decay, at::Tensor per_tensor_decay,
const float global_grad_norm, const float grad_scale)
const float max_global_grad_norm)
{ {
using namespace at; using namespace at;
...@@ -456,8 +452,7 @@ void multi_tensor_lamb_compute_update_term_cuda( ...@@ -456,8 +452,7 @@ void multi_tensor_lamb_compute_update_term_cuda(
per_tensor_epsilon.DATA_PTR<scalar_t_2>(), per_tensor_epsilon.DATA_PTR<scalar_t_2>(),
(adamMode_t) mode, (adamMode_t) mode,
per_tensor_decay.DATA_PTR<scalar_t_2>(), per_tensor_decay.DATA_PTR<scalar_t_2>(),
(scalar_t_2) global_grad_norm, grad_scale); )))
(scalar_t_2) max_global_grad_norm); )))
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
} }
......
...@@ -11,6 +11,7 @@ class FastSelfAttnFunc(torch.autograd.Function) : ...@@ -11,6 +11,7 @@ class FastSelfAttnFunc(torch.autograd.Function) :
dropout_prob_t = torch.tensor([dropout_prob]) dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([]) null_tensor = torch.tensor([])
use_mask = (pad_mask is not None) use_mask = (pad_mask is not None)
mask_additive_t= torch.tensor([mask_additive])
if use_biases_t[0]: if use_biases_t[0]:
if not mask_additive: if not mask_additive:
...@@ -32,9 +33,24 @@ class FastSelfAttnFunc(torch.autograd.Function) : ...@@ -32,9 +33,24 @@ class FastSelfAttnFunc(torch.autograd.Function) :
output_biases, \ output_biases, \
pad_mask if use_mask else null_tensor, \ pad_mask if use_mask else null_tensor, \
dropout_prob) dropout_prob)
ctx.save_for_backward(use_biases_t, \
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
null_tensor, \
null_tensor, \
mask_additive_t, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t)
else: else:
input_lin_results, \ input_lin_results, \
softmax_results, \ bmm1_results, \
dropout_results, \ dropout_results, \
dropout_mask, \ dropout_mask, \
matmul2_results, \ matmul2_results, \
...@@ -51,6 +67,20 @@ class FastSelfAttnFunc(torch.autograd.Function) : ...@@ -51,6 +67,20 @@ class FastSelfAttnFunc(torch.autograd.Function) :
output_biases, \ output_biases, \
pad_mask if use_mask else null_tensor, \ pad_mask if use_mask else null_tensor, \
dropout_prob) dropout_prob)
ctx.save_for_backward(use_biases_t, \
heads_t, \
matmul2_results, \
dropout_results, \
null_tensor, \
bmm1_results, \
pad_mask, \
mask_additive_t, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t)
else: else:
...@@ -70,20 +100,20 @@ class FastSelfAttnFunc(torch.autograd.Function) : ...@@ -70,20 +100,20 @@ class FastSelfAttnFunc(torch.autograd.Function) :
output_weights, \ output_weights, \
pad_mask if use_mask else null_tensor, \ pad_mask if use_mask else null_tensor, \
dropout_prob) dropout_prob)
ctx.save_for_backward(use_biases_t, \
ctx.save_for_backward(use_biases_t, \ heads_t, \
heads_t, \ matmul2_results, \
matmul2_results, \ dropout_results, \
dropout_results, \ softmax_results, \
softmax_results, \ null_tensor, \
input_lin_results, \ null_tensor, \
inputs, \ mask_additive_t, \
input_weights, \ input_lin_results, \
output_weights, \ inputs, \
dropout_mask, \ input_weights, \
dropout_prob_t) output_weights, \
dropout_mask, \
dropout_prob_t)
return outputs.detach() return outputs.detach()
@staticmethod @staticmethod
...@@ -93,6 +123,9 @@ class FastSelfAttnFunc(torch.autograd.Function) : ...@@ -93,6 +123,9 @@ class FastSelfAttnFunc(torch.autograd.Function) :
matmul2_results, \ matmul2_results, \
dropout_results, \ dropout_results, \
softmax_results, \ softmax_results, \
bmm1_results, \
pad_mask, \
mask_additive_t, \
input_lin_results, \ input_lin_results, \
inputs, \ inputs, \
input_weights, \ input_weights, \
...@@ -101,24 +134,45 @@ class FastSelfAttnFunc(torch.autograd.Function) : ...@@ -101,24 +134,45 @@ class FastSelfAttnFunc(torch.autograd.Function) :
dropout_prob_t = ctx.saved_tensors dropout_prob_t = ctx.saved_tensors
if use_biases_t[0]: if use_biases_t[0]:
input_grads, \ if not mask_additive_t[0]:
input_weight_grads, \ input_grads, \
output_weight_grads, \ input_weight_grads, \
input_bias_grads, \ output_weight_grads, \
output_bias_grads = \ input_bias_grads, \
fast_self_multihead_attn_bias.backward( \ output_bias_grads = \
heads_t[0], \ fast_self_multihead_attn_bias.backward( \
output_grads, \ heads_t[0], \
matmul2_results, \ output_grads, \
dropout_results, \ matmul2_results, \
softmax_results, \ dropout_results, \
input_lin_results, \ softmax_results, \
inputs, \ input_lin_results, \
input_weights, \ inputs, \
output_weights, \ input_weights, \
dropout_mask, \ output_weights, \
dropout_prob_t[0]) dropout_mask, \
dropout_prob_t[0])
else:
input_grads, \
input_weight_grads, \
output_weight_grads, \
input_bias_grads, \
output_bias_grads = \
fast_self_multihead_attn_bias_additive_mask.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
bmm1_results, \
pad_mask, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
else: else:
input_bias_grads = None input_bias_grads = None
output_bias_grads = None output_bias_grads = None
......
...@@ -6,7 +6,7 @@ class SelfAttnFunc(torch.autograd.Function): ...@@ -6,7 +6,7 @@ class SelfAttnFunc(torch.autograd.Function):
def forward(ctx, use_time_mask, is_training, heads, scale, inputs, def forward(ctx, use_time_mask, is_training, heads, scale, inputs,
input_weights, output_weights, input_weights, output_weights,
input_biases, output_biases, input_biases, output_biases,
mask, dropout_prob): mask, is_additive_mask, dropout_prob):
use_biases_t = torch.tensor([input_biases is not None]) use_biases_t = torch.tensor([input_biases is not None])
heads_t = torch.tensor([heads]) heads_t = torch.tensor([heads])
scale_t = torch.tensor([scale]) scale_t = torch.tensor([scale])
...@@ -60,8 +60,11 @@ class SelfAttnFunc(torch.autograd.Function): ...@@ -60,8 +60,11 @@ class SelfAttnFunc(torch.autograd.Function):
batches,seql_q,seql_k = matmul1_results.size() batches,seql_q,seql_k = matmul1_results.size()
seqs = int(batches / heads) seqs = int(batches / heads)
matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k) matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
mask = mask.to(torch.bool) if is_additive_mask:
matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf')) matmul1_results = matmul1_results + mask.unsqueeze(1).unsqueeze(2)
else:
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
matmul1_results = matmul1_results.view(seqs*heads, seql_q, seql_k) matmul1_results = matmul1_results.view(seqs*heads, seql_q, seql_k)
softmax_results = F.softmax(matmul1_results, dim=-1) softmax_results = F.softmax(matmul1_results, dim=-1)
......
...@@ -4,13 +4,15 @@ import importlib ...@@ -4,13 +4,15 @@ import importlib
import amp_C import amp_C
from apex.multi_tensor_apply import multi_tensor_applier from apex.multi_tensor_apply import multi_tensor_applier
import torch.distributed.distributed_c10d as c10d
class DistributedFusedAdam(torch.optim.Optimizer): class DistributedFusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``. ``python setup.py install --cuda_ext --cpp_ext``.
It has been proposed in `Adam: A Method for Stochastic Optimization`_. It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments: Arguments:
params (iterable): iterable of parameters to optimize or dicts defining params (iterable): iterable of parameters to optimize or dicts defining
parameter groups. parameter groups.
...@@ -19,20 +21,30 @@ class DistributedFusedAdam(torch.optim.Optimizer): ...@@ -19,20 +21,30 @@ class DistributedFusedAdam(torch.optim.Optimizer):
running averages of gradient and its square. (default: (0.9, 0.999)) running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8) numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
eps_inside_sqrt (boolean, optional): in the 'update parameters' step, eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
adds eps to the bias-corrected second moment estimate before adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False) second moment estimate as in the original paper. (default: False)
use_mt (boolean, optional): use multi tensor apply for lower launch weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
latency. (default: False) amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
overlap_reductions(boolean, optional): whether to overlap reductions overlap_reductions(boolean, optional): whether to overlap reductions
with bprop (default: True) with bprop (default: True)
num_prestats (integer, optional): number of fp64 stats that will be step_supports_amp_scaling(boolean, optional): whether to use customized
reduced during first fp16 gradient reduction block. gradient unscaling logic (default: True)
num_process_groups (integer, optional): number of process groups in
the app (default: 1)
current_process_group (object, optional): the process group to work on
(default: None)
process_group_id (integer, optional): process group id (default: 0)
process_group_size (integer, optional): size of process group
(default: 0)
clip_grad_norm (boolean, optional): whether to handle gradient clipping
(default: True)
model_parallel (boolean, optional): whether model parallelism is used
(default: False)
.. _Adam\: A Method for Stochastic Optimization: .. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980 https://arxiv.org/abs/1412.6980
...@@ -41,22 +53,28 @@ class DistributedFusedAdam(torch.optim.Optimizer): ...@@ -41,22 +53,28 @@ class DistributedFusedAdam(torch.optim.Optimizer):
""" """
def __init__(self, params, def __init__(self, params,
lr=1e-3, bias_correction = True, lr=1e-3, bias_correction=True, betas=(0.9, 0.999),
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False, eps=1e-8, eps_inside_sqrt=False,
weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False, weight_decay=0., max_grad_norm=0.,
amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True, amsgrad=False, flat_mt=False,
compute_L2_grad_norm=False, distributed_weight_update=0, overlap_reductions=True,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4, compute_L2_grad_norm=False,
dwu_num_ag_pg=0, revert_method=1, flat_mt=False, dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
dwu_num_chunks=4, predivide=True, e5m2_allgather=False, dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
do_not_flatten_model=False): predivide=True, e5m2_allgather=False,
global fused_adam_cuda do_not_flatten_model=False,
step_supports_amp_scaling=True,
num_process_groups=1,
current_process_group=None,
process_group_id=0,
process_group_size=0,
clip_grad_norm=True,
model_parallel=False):
global fused_adam_cuda, distributed_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda") fused_adam_cuda = importlib.import_module("fused_adam_cuda")
distributed_adam_cuda = importlib.import_module("distributed_adam_cuda")
self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
self._amp_scale_adjustment = amp_scale_adjustment
if use_mt:
raise RuntimeError('DistributedFusedAdam does not support use_mt.')
if amsgrad: if amsgrad:
raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.') raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')
...@@ -64,21 +82,12 @@ class DistributedFusedAdam(torch.optim.Optimizer): ...@@ -64,21 +82,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
betas=betas, eps=eps, weight_decay=weight_decay, betas=betas, eps=eps, weight_decay=weight_decay,
max_grad_norm=max_grad_norm) max_grad_norm=max_grad_norm)
super(DistributedFusedAdam, self).__init__(params, defaults) super(DistributedFusedAdam, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
# Misc
self.eps_mode = 0 if eps_inside_sqrt else 1
self._overflow_buf = torch.cuda.IntTensor([0]) self._overflow_buf = torch.cuda.IntTensor([0])
self._has_overflow = False self._has_overflow = False
self._step_supports_amp_scaling = step_supports_amp_scaling
assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
# Way to revert a step
# 3 -> undo kernel + double buffer (debug, print norm of difference)
# 2 -> double buffer fp32 parameters
# 1 -> undo kernel
self._revert_method = revert_method
if self._revert_method > 1:
print("revert_method -> double buffer fp32 parameters, will consume more memory")
self._last_step = False self._last_step = False
self._overlap_reductions = overlap_reductions self._overlap_reductions = overlap_reductions
self._global_scale = None self._global_scale = None
...@@ -87,33 +96,64 @@ class DistributedFusedAdam(torch.optim.Optimizer): ...@@ -87,33 +96,64 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._predivide = predivide self._predivide = predivide
self._e5m2_allgather = e5m2_allgather self._e5m2_allgather = e5m2_allgather
self._do_not_flatten_model = do_not_flatten_model self._do_not_flatten_model = do_not_flatten_model
self._full_pipeline = full_pipeline
self._compute_L2_grad_norm = compute_L2_grad_norm self._compute_L2_grad_norm = compute_L2_grad_norm
self._L2_grad_norm = None self._L2_grad_norm = None
self._flat_mt = flat_mt
self._init_done = False
self._resume_from_checkpoint = False
self._step = 0
# Process group related
self._clip_grad_norm = clip_grad_norm
self._model_parallel = model_parallel
self._num_process_groups = num_process_groups
self._current_process_group = current_process_group if current_process_group is not None else c10d._get_default_group()
self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys())
self._process_group_id = process_group_id
self._process_group_size = torch.cuda.device_count() if process_group_size <= 0 else process_group_size
self._world_size = self._process_group_size # world: the current process group
self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
self._world_size = torch.distributed.get_world_size()
self._num_groups = self._world_size // self._group_size self._num_groups = self._world_size // self._group_size
self._rank_in_group = torch.distributed.get_rank() % self._group_size self._global_rank = torch.distributed.get_rank()
self._world_rank = self._global_rank // self._num_process_groups
self._group_rank = self._world_rank % self._group_size
#print("world_size:", self._world_size, ", group_size:", self._group_size, ", num_groups:", self._num_groups, ", global_rank:", self._global_rank, ", world_rank:", self._world_rank, ", group_rank:", self._group_rank)
self._num_rs_pg = dwu_num_rs_pg
self._num_ar_pg = dwu_num_ar_pg
self._num_ag_pg = dwu_num_ag_pg
# Master weight, moment, gradient buffers
self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None
def _first_step_init(self):
p_offset = 0 p_offset = 0
p_i = 0 p_i = 0
self._param_state = None
self._model_params = [] self._model_params = []
self._grads_info = [] self._grads_info = []
self._grad_accs = [] self._grad_accs = []
self._group_properties = []
for group in self.param_groups: for group in self.param_groups:
self._param_group = group self._param_group = group
prev = None prev = None
beta1, beta2 = group['betas']
bias_correction = 1 if group['bias_correction'] else 0
eps = group['eps']
weight_decay = group['weight_decay']
for p in group['params']: for p in group['params']:
torch.distributed.broadcast(p,0) # broadcast from rank 0 of current process group
torch.distributed.broadcast(p, src=self._available_ranks[0], group=self._current_process_group)
if not p.requires_grad: if not p.requires_grad:
continue continue
self._model_params.append(p) self._model_params.append(p)
state = self.state[p] # Multiple param groups support:
if len(state) == 0: # store one hyperparam item per parameter tensor
state['step'] = 0 self._group_properties.append((
if self._param_state is None: beta1,
self._param_state = state beta2,
bias_correction,
eps,
weight_decay
))
p_grads_size = p.numel() p_grads_size = p.numel()
def wrapper(param, param_i, param_grads_size, param_offset): def wrapper(param, param_i, param_grads_size, param_offset):
param_tmp = param.expand_as(param) param_tmp = param.expand_as(param)
...@@ -133,7 +173,6 @@ class DistributedFusedAdam(torch.optim.Optimizer): ...@@ -133,7 +173,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
prev = p prev = p
p_i += 1 p_i += 1
self._grads_generated = [False]*len(self._grads_info) self._grads_generated = [False]*len(self._grads_info)
self._flat_mt = flat_mt
self._grads = [] self._grads = []
if self._overlap_reductions: if self._overlap_reductions:
self._current_block = self._num_blocks self._current_block = self._num_blocks
...@@ -145,7 +184,7 @@ class DistributedFusedAdam(torch.optim.Optimizer): ...@@ -145,7 +184,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._block_size = self._total_param_size // self._num_blocks self._block_size = self._total_param_size // self._num_blocks
self._chunk_size = self._block_size // self._num_chunks self._chunk_size = self._block_size // self._num_chunks
self._shard_size = self._chunk_size // self._group_size self._shard_size = self._chunk_size // self._group_size
print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size)) #print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
self._low_param_i = [0]*self._num_blocks self._low_param_i = [0]*self._num_blocks
for block_id in range(self._num_blocks-1,-1,-1): for block_id in range(self._num_blocks-1,-1,-1):
...@@ -153,14 +192,16 @@ class DistributedFusedAdam(torch.optim.Optimizer): ...@@ -153,14 +192,16 @@ class DistributedFusedAdam(torch.optim.Optimizer):
while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size: while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
p_i -= 1 p_i -= 1
self._low_param_i[block_id] = p_i self._low_param_i[block_id] = p_i
print(self._low_param_i) #print(self._low_param_i)
self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda') self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda') self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda') # initialize master weights, moments buffers if not loaded from checkpoint
self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda') if self._fp32_p is None:
self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda') self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
# FIXME: Rethink fp16 label since it's either uint8 or fp16 # FIXME: Rethink fp16 label since it's either uint8 or fp16
self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda') self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda') self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')
...@@ -213,12 +254,15 @@ class DistributedFusedAdam(torch.optim.Optimizer): ...@@ -213,12 +254,15 @@ class DistributedFusedAdam(torch.optim.Optimizer):
# 1) Copy model parameters into master buffer # 1) Copy model parameters into master buffer
# 2) Create tensor lists for unpacking new parameter tensor after all-gather # 2) Create tensor lists for unpacking new parameter tensor after all-gather
self._packed_flat_to_model_params = [] self._packed_flat_to_model_params = []
self._contrib_tensor_list = []
self._contrib_group_properties = []
self._non_parallel_grads = []
for shard_id in range(self._group_size): for shard_id in range(self._group_size):
for block_id in range(self._num_blocks): for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks): for chunk_id in range(self._num_chunks):
flat_shard_start = (((block_id * self._num_chunks + chunk_id) * self._group_size) + shard_id) * self._shard_size flat_shard_start = (((block_id * self._num_chunks + chunk_id) * self._group_size) + shard_id) * self._shard_size
flat_shard_end = flat_shard_start + self._shard_size flat_shard_end = flat_shard_start + self._shard_size
for p, grads_info in zip(self._model_params, self._grads_info): for (p, grads_info, group_props) in zip(self._model_params, self._grads_info, self._group_properties):
flat_grad_start = grads_info["param_offset"] flat_grad_start = grads_info["param_offset"]
flat_grad_end = flat_grad_start + grads_info["param_grads_size"] flat_grad_end = flat_grad_start + grads_info["param_grads_size"]
clipped_start = (lambda a,b: a if a > b else b)(flat_grad_start, flat_shard_start) clipped_start = (lambda a,b: a if a > b else b)(flat_grad_start, flat_shard_start)
...@@ -230,60 +274,90 @@ class DistributedFusedAdam(torch.optim.Optimizer): ...@@ -230,60 +274,90 @@ class DistributedFusedAdam(torch.optim.Optimizer):
model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length] model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]
new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length] new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]
self._packed_flat_to_model_params.append( (new_param_packed_fragment, model_param_fragment) ) self._packed_flat_to_model_params.append( (new_param_packed_fragment, model_param_fragment) )
if shard_id == self._rank_in_group: if shard_id == self._group_rank:
# copy model parameters into master buffer # copy model parameters into master buffer
master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length] master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size()))) opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
master_param_fragment.copy_(model_param_fragment) opti_state_v_fragment = self._fp32_v_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
#print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
if not self._resume_from_checkpoint:
master_param_fragment.copy_(model_param_fragment)
self._contrib_group_properties.append(group_props)
self._contrib_tensor_list.append((master_param_fragment, opti_state_m_fragment, opti_state_v_fragment, opti_state_g_fragment, opti_state_p_fragment)) # p, m, v, g, p_copy
if self._model_parallel and hasattr(p, 'model_parallel') and not p.model_parallel:
self._non_parallel_grads.append(opti_state_g_fragment)
p, m, v, g, p_copy = list(zip(*self._contrib_tensor_list))
self._contrib_tensor_list = [p, m, v, g, p_copy]
math_type = self._fp32_p.dtype
beta1, beta2, bias_correction, epsilon, decay = list(zip(*self._contrib_group_properties))
self._contrib_beta1 = torch.tensor(beta1, dtype=math_type, device='cuda')
self._contrib_beta2 = torch.tensor(beta2, dtype=math_type, device='cuda')
self._contrib_bias_correction = torch.tensor(bias_correction, dtype=torch.int, device='cuda')
self._contrib_epsilon = torch.tensor(epsilon, dtype=math_type, device='cuda')
self._contrib_weight_decay = torch.tensor(decay, dtype=math_type, device='cuda')
p_in, p_out = zip(*self._packed_flat_to_model_params) p_in, p_out = zip(*self._packed_flat_to_model_params)
self._packed_flat_to_model_params = [p_in, p_out] self._packed_flat_to_model_params = [p_in, p_out]
self._distributed_weight_update = distributed_weight_update # Is this still needed?
self._num_rs_pg = dwu_num_rs_pg
self._num_ar_pg = dwu_num_ar_pg
self._num_ag_pg = dwu_num_ag_pg
if self._num_groups > 1: if self._num_groups > 1:
self._ar_pg = [] self._ar_pg = []
for dev_i in range(self._group_size): for i in range(self._num_process_groups):
ranks = [dev_i+j*self._group_size for j in range(self._num_groups)] # gather global ranks of all members of the current process group
for i in range(self._num_ar_pg): ranks = [i+k*self._num_process_groups for k in range(self._process_group_size)]
grp = torch.distributed.new_group(ranks=ranks) for j in range(self._group_size):
if torch.distributed.get_rank() in ranks: ar_idx = [j+k*self._group_size for k in range(self._num_groups)]
self._ar_pg.append(grp) ar_rank = [ranks[k] for k in ar_idx]
#if self._global_rank in ar_rank:
# print("group for all reduce, ranks:", ar_rank)
for _ in range(self._num_ar_pg):
grp = torch.distributed.new_group(ranks=ar_rank)
if self._global_rank in ar_rank:
self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)] self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
for ar_pg in self._ar_pg: for ar_pg in self._ar_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ar_pg) torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
rs_ranks = []
for group_i in range(self._num_groups): self._rs_pg, rs_ranks = [],[]
rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)]) for i in range(self._num_process_groups):
self._rs_pg = [] ranks = [i+k*self._num_process_groups for k in range(self._process_group_size)]
for group_i in range(self._num_groups): for j in range(self._num_groups):
ranks = rs_ranks[group_i] rs_idx = [j*self._group_size+k for k in range(self._group_size)]
for i in range(self._num_rs_pg): rs_rank = [ranks[k] for k in rs_idx]
grp = torch.distributed.new_group(ranks=ranks) #if self._global_rank in rs_rank:
if torch.distributed.get_rank() in ranks: # print("group for reduce scatter, ranks:", rs_rank)
self._rs_pg.append(grp) for _ in range(self._num_rs_pg):
if self._compute_L2_grad_norm: grp = torch.distributed.new_group(ranks=rs_rank)
l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks) if self._global_rank in rs_rank:
if torch.distributed.get_rank() in ranks: self._rs_pg.append(grp)
self._l2_grad_norm_pg = l2_grad_norm_pg if self._compute_L2_grad_norm:
torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg) l2_grad_norm_pg = torch.distributed.new_group(ranks=rs_rank)
if self._global_rank in rs_rank:
self._l2_grad_norm_pg = l2_grad_norm_pg
torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)] self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
for rs_pg in self._rs_pg: for rs_pg in self._rs_pg:
torch.distributed.all_reduce(self._overflow_buf,group=rs_pg) torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
if self._num_ag_pg == 0: if self._num_ag_pg == 0:
self._ag_pg = self._rs_pg self._ag_pg = self._rs_pg
self._ag_st = self._rs_st self._ag_st = self._rs_st
self._num_ag_pg = self._num_rs_pg self._num_ag_pg = self._num_rs_pg
else: else:
self._ag_pg = [] self._ag_pg = []
for group_i in range(self._num_groups): for i in range(self._num_process_groups):
ranks = rs_ranks[group_i] ranks = [i+k*self._num_process_groups for k in range(self._process_group_size)]
for i in range(self._num_ag_pg): for j in range(self._num_groups):
grp = torch.distributed.new_group(ranks=ranks) ag_rank = rs_ranks[j]
if torch.distributed.get_rank() in ranks: #if self._global_rank in ag_rank:
self._ag_pg.append(grp) # print("group for all gather, ranks:", ag_rank)
for _ in range(self._num_ag_pg):
grp = torch.distributed.new_group(ranks=ag_rank)
if self._global_rank in ag_rank:
self._ag_pg.append(grp)
self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)] self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
for ag_pg in self._ag_pg: for ag_pg in self._ag_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ag_pg) torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
...@@ -296,6 +370,10 @@ class DistributedFusedAdam(torch.optim.Optimizer): ...@@ -296,6 +370,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
import inspect import inspect
assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option" assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
def _init_everything(self):
if not self._init_done:
self._first_step_init()
self._init_done = True
def set_last_step(self, last_step): def set_last_step(self, last_step):
self._last_step = last_step self._last_step = last_step
...@@ -350,46 +428,43 @@ class DistributedFusedAdam(torch.optim.Optimizer): ...@@ -350,46 +428,43 @@ class DistributedFusedAdam(torch.optim.Optimizer):
l2_grad_norm_sq = torch.empty([1], device='cuda') l2_grad_norm_sq = torch.empty([1], device='cuda')
l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2 l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg) torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
# for model_parallel_rank=0, keep all gradients
# for the rest, subtract non_parallel gradients
if self._model_parallel and self._process_group_id: # non zero model_parallel_rank
non_parallel_grad_norm_sq = torch.zeros([1], device='cuda')
if len(self._non_parallel_grads): # non-parallel grads exist
non_parallel_grad_norm_sq = multi_tensor_applier(self.multi_tensor_l2norm,
self._overflow_buf,
[self._non_parallel_grads], False)[0]**2
torch.distributed.all_reduce(non_parallel_grad_norm_sq, group=self._l2_grad_norm_pg)
l2_grad_norm_sq = l2_grad_norm_sq - non_parallel_grad_norm_sq
self._L2_grad_norm = l2_grad_norm_sq.sqrt().item() self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()
def __launch_step_kernel(self, p, p_copy, m, v, g): def __launch_step_kernel(self):
# If self._clip_grad_norm is False, we assume gradient clipping already
# happened outside the optimizer and self._global_scale has already
# been set to the combined scale, i.e. it's no longer the current loss
# scale used by the loss scaler.
# This covers model-parallel cases where the global gradient norm must be
# obtained via all-reduce outside the optimizer in order to do the clipping.
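# Worked example with illustrative numbers (not from the source): if
# global_scale=65536, the unscaled gradient norm is 2.0 (so L2_grad_norm=131072)
# and max_grad_norm=1.0, then combined_scale = 65536 / min(1, 1.0/2.0) ~= 131072,
# and dividing the gradients by it both removes the loss scale and clips the
# global norm from 2.0 down to 1.0.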
combined_scale = self._global_scale combined_scale = self._global_scale
if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm): if self._clip_grad_norm and self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6) combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale) combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if self._param_group['bias_correction'] else 0
beta1, beta2 = self._param_group['betas'] self._step += 1
fused_adam_cuda.reversible_adam( multi_tensor_applier(distributed_adam_cuda.multi_tensor_fused_adam,
p, p_copy, m, v, g, self._overflow_buf,
self._contrib_tensor_list, # p, m, v, g, p_copy
self._contrib_beta1,
self._contrib_beta2,
self._contrib_bias_correction,
self._contrib_epsilon,
self._contrib_weight_decay,
self._param_group['lr'], self._param_group['lr'],
beta1,
beta2,
self._param_group['eps'],
combined_scale, combined_scale,
self._param_state['step']+1, self._step,
self.eps_mode, self.eps_mode)
bias_correction,
self._param_group['weight_decay'])
def _pipeline_block_step(self, block_id):
# Call step kernel once per block
ag_stream = self._ag_st[block_id%self._num_ag_pg]
with torch.cuda.stream(ag_stream):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
self.__launch_step_kernel(
self._fp32_p_blocks[block_id],
self._fp16_p_blocks[block_id],
self._fp32_m_blocks[block_id],
self._fp32_v_blocks[block_id],
self._fp16_g_blocks[block_id])
# Call all-gather once per step.
# FIXME: Determine which is faster, one all-gather per block or a single all-gather at end
if block_id == 0:
for other_ag_stream in self._ag_st:
self._completion_st.wait_stream(other_ag_stream)
with torch.cuda.stream(self._completion_st):
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _pipeline_step(self): def _pipeline_step(self):
# Call step kernel once per step # Call step kernel once per step
...@@ -398,12 +473,7 @@ class DistributedFusedAdam(torch.optim.Optimizer): ...@@ -398,12 +473,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
for block_id in range(self._num_blocks): for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks): for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait() self._reductions_works[block_id][chunk_id].wait()
self.__launch_step_kernel( self.__launch_step_kernel()
self._fp32_p,
self._fp16_p,
self._fp32_m,
self._fp32_v,
self._fp16_g)
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True) torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _flatten_grad_mt(self, scale): def _flatten_grad_mt(self, scale):
...@@ -429,8 +499,6 @@ class DistributedFusedAdam(torch.optim.Optimizer): ...@@ -429,8 +499,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
while flush_block: while flush_block:
block_id = flush_block[0] // self._block_size block_id = flush_block[0] // self._block_size
self._pipeline_block_reductions(block_id) self._pipeline_block_reductions(block_id)
if self._full_pipeline:
self._pipeline_block_step(block_id)
flush_block = self._get_flush_block() flush_block = self._get_flush_block()
def set_global_scale(self, global_scale): def set_global_scale(self, global_scale):
...@@ -484,7 +552,7 @@ class DistributedFusedAdam(torch.optim.Optimizer): ...@@ -484,7 +552,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
def complete_reductions(self): def complete_reductions(self):
"""Complete reductions if full pipeline is not selected or overlap is not allowed. """Complete reductions if full pipeline is not selected or overlap is not allowed.
""" """
self._init_everything()
if self._last_step: if self._last_step:
# zero out gradients that have not been completed yet # zero out gradients that have not been completed yet
for param_i, grad_generated in enumerate(self._grads_generated): for param_i, grad_generated in enumerate(self._grads_generated):
...@@ -506,53 +574,19 @@ class DistributedFusedAdam(torch.optim.Optimizer): ...@@ -506,53 +574,19 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._current_block = self._num_blocks self._current_block = self._num_blocks
self._grads_generated = [False]*len(self._grads_info) self._grads_generated = [False]*len(self._grads_info)
def revert_step(self): def step(self, closure=None):
"""Revert effect of previously calling partial_step.
"""
# Call undo kernel once per step
combined_scale = self._global_scale
if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if self._param_group['bias_correction'] else 0
beta1, beta2 = self._param_group['betas']
fused_adam_cuda.maybe_adam_undo(
torch.empty([0]),
self._fp32_p,
self._fp32_m,
self._fp32_v,
self._fp16_g,
self._param_group['lr'],
beta1,
beta2,
self._param_group['eps'],
combined_scale,
self._param_state['step']+1,
self.eps_mode,
bias_correction,
self._param_group['weight_decay'])
def step(self, closure=None, skip_overflow_check=False):
loss = None loss = None
if closure is not None: if closure is not None:
loss = closure() loss = closure()
if self._last_step or not self._overlap_reductions or not self._full_pipeline: self._pipeline_step()
self._pipeline_step()
with torch.cuda.stream(self._completion_st): with torch.cuda.stream(self._completion_st):
# Check for overflow # Copy self._new_params to model params
# Store state for loss scaler calculation multi_tensor_applier(
has_overflow = False if skip_overflow_check else self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size) fused_adam_cuda.maybe_cast_mt,
if has_overflow: self._overflow_buf,
self.revert_step() self._packed_flat_to_model_params)
else:
# Copy self._new_params to model params
for p in self._model_params: self.state[p]['step'] += 1
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
self._packed_flat_to_model_params)
torch.cuda.current_stream().wait_stream(self._completion_st) torch.cuda.current_stream().wait_stream(self._completion_st)
...@@ -561,4 +595,42 @@ class DistributedFusedAdam(torch.optim.Optimizer): ...@@ -561,4 +595,42 @@ class DistributedFusedAdam(torch.optim.Optimizer):
return loss return loss
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`DistributedFusedAdam` instance.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
# save step, master weights and first/second moments
state_dict = {}
state_dict['step'] = self._step
state_dict['fp32_p'] = self._fp32_p
state_dict['fp32_m'] = self._fp32_m
state_dict['fp32_v'] = self._fp32_v
return state_dict
def load_state_dict(self, state_dict):
"""
Loads a state_dict created by an earlier call to state_dict().
If a DistributedFusedAdam instance was constructed with parameters that
came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``optimizer.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = DistributedFusedAdam(model.parameters(), lr=1e-3)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# restore step, master weights and first/second moments
self._step = state_dict['step']
self._fp32_p = state_dict['fp32_p'].to(device="cuda")
self._fp32_m = state_dict['fp32_m'].to(device="cuda")
self._fp32_v = state_dict['fp32_v'].to(device="cuda")
self._resume_from_checkpoint = True
...@@ -56,6 +56,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -56,6 +56,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
(default: 1.0) (default: 1.0)
use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0 use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
weight decay parameter (default: False) weight decay parameter (default: False)
clip_grad_norm (boolean, optional): whether to handle gradient clipping
(default: True)
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962 https://arxiv.org/abs/1904.00962
...@@ -67,7 +69,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -67,7 +69,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
lr=1e-3, bias_correction = True, grad_averaging=True, lr=1e-3, bias_correction = True, grad_averaging=True,
betas=(0.9, 0.999), eps=1e-8, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0., max_grad_norm=0., weight_decay=0., max_grad_norm=0.,
adam_w_mode=True, use_nvlamb=False, adam_w_mode=True, use_nvlamb=False, clip_grad_norm=True,
amp_scale_adjustment=1.0, overlap_reductions=True, amp_scale_adjustment=1.0, overlap_reductions=True,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4, dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
...@@ -89,6 +91,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -89,6 +91,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
'max_grad_norm': max_grad_norm, 'max_grad_norm': max_grad_norm,
'adam_w_mode': adam_w_mode, 'adam_w_mode': adam_w_mode,
'use_nvlamb': use_nvlamb, 'use_nvlamb': use_nvlamb,
'clip_grad_norm': clip_grad_norm,
'amp_scale_adjustment': amp_scale_adjustment, 'amp_scale_adjustment': amp_scale_adjustment,
'overlap_reductions': overlap_reductions, 'overlap_reductions': overlap_reductions,
'dwu_group_size': dwu_group_size, 'dwu_group_size': dwu_group_size,
...@@ -107,7 +110,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -107,7 +110,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
lr=1e-3, bias_correction = True, grad_averaging=True, lr=1e-3, bias_correction = True, grad_averaging=True,
betas=(0.9, 0.999), eps=1e-8, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0., max_grad_norm=0., weight_decay=0., max_grad_norm=0.,
adam_w_mode=True, use_nvlamb=False, adam_w_mode=True, use_nvlamb=False, clip_grad_norm=True,
amp_scale_adjustment=1.0, overlap_reductions=True, amp_scale_adjustment=1.0, overlap_reductions=True,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4, dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
...@@ -127,9 +130,11 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -127,9 +130,11 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._adam_w_mode = 1 if adam_w_mode else 0 self._adam_w_mode = 1 if adam_w_mode else 0
self._use_nvlamb = use_nvlamb self._use_nvlamb = use_nvlamb
self._clip_grad_norm = clip_grad_norm
self._is_accumulation_step = False self._is_accumulation_step = False
self._last_step = False self._last_step = False
self._overlap_reductions = overlap_reductions self._overlap_reductions = overlap_reductions
self._global_scale = None
self._num_blocks = dwu_num_blocks self._num_blocks = dwu_num_blocks
self._num_chunks = dwu_num_chunks self._num_chunks = dwu_num_chunks
self._e5m2_allgather = e5m2_allgather self._e5m2_allgather = e5m2_allgather
...@@ -468,9 +473,23 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -468,9 +473,23 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True)[1] ** 2 local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True)[1] ** 2
l2_norm.masked_scatter_(self._model_param_is_contrib, local_contrib_l2_norm) l2_norm.masked_scatter_(self._model_param_is_contrib, local_contrib_l2_norm)
torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0]) torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0])
l2_norm = torch.sqrt(l2_norm)
return l2_norm.masked_select(self._model_param_is_contrib) return l2_norm.masked_select(self._model_param_is_contrib)
def _pipeline_step(self): def _pipeline_step(self):
# If self._clip_grad_norm is False, we assume gradient clipping already
# happened outside the optimizer and self._global_scale has already
# been set to the combined scale, i.e. it's no longer the current loss
# scale used by the loss scaler.
# This covers model-parallel cases where the global gradient norm must be
# obtained via all-reduce outside the optimizer in order to do the clipping.
combined_scale = self.global_scale
max_grad_norm = self.defaults['max_grad_norm']
global_grad_norm = self.L2_grad_norm
if self._clip_grad_norm and max_grad_norm > 0 and math.isfinite(global_grad_norm):
combined_scale = max_grad_norm / (global_grad_norm / self.global_scale + 1e-6)
combined_scale = self.global_scale / min(1, combined_scale)
# Call step kernel once per step # Call step kernel once per step
# Call all-gather once per step # Call all-gather once per step
with torch.cuda.stream(self._completion_st): with torch.cuda.stream(self._completion_st):
...@@ -478,7 +497,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -478,7 +497,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
for chunk_id in range(self._num_chunks): for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait() self._reductions_works[block_id][chunk_id].wait()
param_norm = self.__compute_contrib_param_norm() param_norm = self.__compute_contrib_param_norm()
max_grad_norm = self.defaults['max_grad_norm']
multi_tensor_applier(self.multi_tensor_lamb_compute_update_term, multi_tensor_applier(self.multi_tensor_lamb_compute_update_term,
self._overflow_buf, self._overflow_buf,
self._contrib_compute_update_term_tensor_list, # g, p, m, v, u self._contrib_compute_update_term_tensor_list, # g, p, m, v, u
...@@ -490,8 +508,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -490,8 +508,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._contrib_epsilon, self._contrib_epsilon,
self._adam_w_mode, self._adam_w_mode,
self._contrib_weight_decay, self._contrib_weight_decay,
self.L2_grad_norm, combined_scale)
max_grad_norm)
upd_norm = self.__compute_contrib_update_norm() upd_norm = self.__compute_contrib_update_norm()
multi_tensor_applier(self.multi_tensor_lamb_update_weights, multi_tensor_applier(self.multi_tensor_lamb_update_weights,
self._overflow_buf, self._overflow_buf,
...@@ -537,6 +554,15 @@ class DistributedFusedLAMB(torch.optim.Optimizer): ...@@ -537,6 +554,15 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._pipeline_block_reductions(block_id) self._pipeline_block_reductions(block_id)
flush_block = self._get_flush_block() flush_block = self._get_flush_block()
def set_global_scale(self, global_scale):
"""Set global scale.
"""
self._global_scale = global_scale
@property
def global_scale(self):
return self._global_scale
@property @property
def L2_grad_norm(self): def L2_grad_norm(self):
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st) torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
......
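For readers skimming the diff, the combined-scale logic added to `_pipeline_step` above can be read in isolation as the following minimal sketch. The function name and plain-float arguments are illustrative only; `global_scale`, `global_grad_norm`, `max_grad_norm`, and `clip_grad_norm` mirror the names used in the diff, and this is not the optimizer's actual code path.
```
import math

def combined_scale(global_scale, global_grad_norm, max_grad_norm, clip_grad_norm=True):
    # Start from the loss scale; when clipping is disabled (or the norm is not
    # finite), the step kernel simply unscales gradients by this value.
    scale = global_scale
    if clip_grad_norm and max_grad_norm > 0 and math.isfinite(global_grad_norm):
        # Ratio of the allowed norm to the unscaled gradient norm.
        clip = max_grad_norm / (global_grad_norm / global_scale + 1e-6)
        # Fold the clip factor into the scale; if no clipping is needed
        # (clip >= 1), this reduces to the plain loss scale.
        scale = global_scale / min(1, clip)
    return scale
```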
# Introduction to ASP # Introduction to ASP
This page documents the API for ASP (Automatic Sparsity), a tool that enables sparse training and inference for PyTorch models by adding 2 lines of Python. This serves as a quick-start for ASP (Automatic SParsity), a tool that enables sparse training and inference for PyTorch models by adding 2 lines of Python.
## Importing ASP ## Importing ASP
``` ```
...@@ -14,8 +14,7 @@ Apart from the import statement, it is sufficient to add just the following line ...@@ -14,8 +14,7 @@ Apart from the import statement, it is sufficient to add just the following line
ASP.prune_trained_model(model, optimizer) ASP.prune_trained_model(model, optimizer)
``` ```
In a typical PyTorch training loop, it might look like this: In the context of a typical PyTorch training loop, it might look like this:
``` ```
ASP.prune_trained_model(model, optimizer) ASP.prune_trained_model(model, optimizer)
...@@ -28,10 +27,52 @@ for epoch in range(epochs): ...@@ -28,10 +27,52 @@ for epoch in range(epochs):
torch.save(...) torch.save(...)
``` ```
The `prune_trained_model` calculates the sparse mask and applies it to the weights. This is done once, i.e., sparse locations in the weights matrix remain fixed after this step. In order to recompute the sparse mask in between training, say after an epoch, use the following method: The `prune_trained_model` step calculates the sparse mask and applies it to the weights. This is done once, i.e., sparse locations in the weights matrix remain fixed after this step.
## Generate a Sparse Network
The following approach serves as a guiding example of how to generate a pruned model that can use Sparse Tensor Cores in the NVIDIA Ampere Architecture. This approach produces a model for deployment, i.e., inference mode.
```
(1) Given a fully trained (dense) network, prune parameter values in a 2:4 sparse pattern.
(2) Fine-tune the pruned model using the same optimization method and hyper-parameters (learning rate, schedule, number of epochs, etc.) as those used to obtain the dense trained model.
(3) (If required) Quantize the model.
```
In code, below is a sketch of how to use ASP for this approach (steps 1 and 2 above).
```
model = define_model(..., pretrained=True) # define model architecture and load parameter tensors with trained values (by reading a trained checkpoint)
criterion = ... # compare ground truth with model prediction; use the same criterion as used to generate the dense trained model
optimizer = ... # optimize model parameters; use the same optimizer as used to generate the dense trained model
lr_scheduler = ... # learning rate scheduler; use the same schedule as used to generate the dense trained model
from apex.contrib.sparsity import ASP
ASP.prune_trained_model(model, optimizer) # prune the trained model
x, y = DataLoader(args)
for epoch in range(epochs): # train the pruned model for the same number of epochs as used to generate the dense trained model
    optimizer.zero_grad()
    y_pred = model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
torch.save(...) # saves the pruned checkpoint with sparsity masks
```
## Non-Standard Usage
If your goal is to easily prepare a network for accelerated inference, please follow the recipe above. However, ASP can also be used to perform experiments with advanced techniques such as training with sparsity from initialization. For example, in order to recompute the sparse mask in between training steps, use the following method:
``` ```
ASP.compute_sparse_masks() ASP.compute_sparse_masks()
``` ```
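For example, the sketch below recomputes the sparsity masks once per epoch. It assumes `model`, `optimizer`, `criterion`, `epochs`, and a `train_loader` iterable of batches are already defined; those names and the recompute interval are placeholders, and only `ASP.prune_trained_model` and `ASP.compute_sparse_masks` come from the API described above.
```
from apex.contrib.sparsity import ASP

ASP.prune_trained_model(model, optimizer)  # compute and apply the initial masks

for epoch in range(epochs):
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
    # Re-evaluate which weight locations are pruned, e.g. once per epoch.
    ASP.compute_sparse_masks()
```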
A more thorough example can be found in `./test/toy_problem.py`. A more thorough example can be found in `./test/toy_problem.py`.
\ No newline at end of file
import torch
import unittest
from apex.contrib.multihead_attn import SelfMultiheadAttn
class SelfMultiheadAttnTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.seq_length = 80
self.sequences = 10
self.hidden_dim = 1024
self.heads = 16
self.dropout_prob = 0.0
self.ref_layer = SelfMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=True,
include_norm_add=False,
separate_qkv_params=True,
mask_additive=True,
impl='default')
self.ref_layer.cuda().half()
self.ref_layer.reset_parameters()
self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
# Reset seed so parameters are identical
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.tst_layer = SelfMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=True,
include_norm_add=False,
separate_qkv_params=True,
mask_additive=True,
impl='fast')
self.tst_layer.cuda().half()
self.tst_layer.reset_parameters()
self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
def test_self_multihead_attn_additive_mask(self) :
grads = torch.randn_like(self.tst_inputs)
mask = ((torch.randn(self.sequences, self.seq_length) > 0) * -10000.0).half().cuda()
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs,
self.ref_inputs,
self.ref_inputs,
key_padding_mask=mask,
need_weights=False,
attn_mask=None,
is_training=True)
tst_outputs,_ = self.tst_layer.forward(self.tst_inputs,
self.tst_inputs,
self.tst_inputs,
key_padding_mask=mask,
need_weights=False,
attn_mask=None,
is_training=True)
self.ref_inputs.backward(grads)
self.tst_inputs.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))
if __name__ == '__main__':
unittest.main()
...@@ -76,7 +76,7 @@ class FusedLAMB(torch.optim.Optimizer): ...@@ -76,7 +76,7 @@ class FusedLAMB(torch.optim.Optimizer):
import amp_C import amp_C
self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
# Skip buffer # Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0]) self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
self.multi_tensor_lamb = amp_C.multi_tensor_lamb self.multi_tensor_lamb = amp_C.multi_tensor_lamb
else: else:
raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions') raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')
...@@ -117,7 +117,8 @@ class FusedLAMB(torch.optim.Optimizer): ...@@ -117,7 +117,8 @@ class FusedLAMB(torch.optim.Optimizer):
else: else:
raise RuntimeError('FusedLAMB only supports fp16 and fp32.') raise RuntimeError('FusedLAMB only supports fp16 and fp32.')
g_norm_32, g_norm_16 = torch.zeros(1, device='cuda'), torch.zeros(1, device='cuda') device = self.param_groups[0]["params"][0].device
g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device)
# compute grad norm for two lists # compute grad norm for two lists
if len(g_all_32) > 0: if len(g_all_32) > 0:
g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm, g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
......
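The FusedLAMB changes above replace hard-coded `torch.cuda.IntTensor([0])` and `device='cuda'` allocations with buffers created on the device of the first parameter. A minimal sketch of that pattern, with an illustrative helper name not taken from the source:
```
import torch

def make_step_buffers(param_groups):
    # Allocate helper buffers on the same device as the model parameters,
    # rather than assuming the default CUDA device.
    device = param_groups[0]["params"][0].device
    dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=device)
    g_norm_32 = torch.zeros(1, device=device)
    g_norm_16 = torch.zeros(1, device=device)
    return dummy_overflow_buf, g_norm_32, g_norm_16
```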
...@@ -76,7 +76,8 @@ class FusedSGD(Optimizer): ...@@ -76,7 +76,8 @@ class FusedSGD(Optimizer):
def __init__(self, params, lr=required, momentum=0, dampening=0, def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False, weight_decay=0, nesterov=False,
wd_after_momentum=False, wd_after_momentum=False,
materialize_master_grads=True): materialize_master_grads=True,
set_grad_none=False):
if lr is not required and lr < 0.0: if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr)) raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0: if momentum < 0.0:
...@@ -94,11 +95,12 @@ class FusedSGD(Optimizer): ...@@ -94,11 +95,12 @@ class FusedSGD(Optimizer):
self.materialize_master_grads = materialize_master_grads self.materialize_master_grads = materialize_master_grads
self.most_recent_scale = 1.0 self.most_recent_scale = 1.0
self.scale_set_by_backward = False self.scale_set_by_backward = False
self.set_grad_none = set_grad_none
if multi_tensor_applier.available: if multi_tensor_applier.available:
import amp_C import amp_C
# Skip buffer # Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0]) self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
self.multi_tensor_sgd = amp_C.multi_tensor_sgd self.multi_tensor_sgd = amp_C.multi_tensor_sgd
else: else:
raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions') raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')
...@@ -108,6 +110,14 @@ class FusedSGD(Optimizer): ...@@ -108,6 +110,14 @@ class FusedSGD(Optimizer):
for group in self.param_groups: for group in self.param_groups:
group.setdefault('nesterov', False) group.setdefault('nesterov', False)
def zero_grad(self):
if self.set_grad_none:
for group in self.param_groups:
for p in group['params']:
p.grad = None
else:
super(FusedSGD, self).zero_grad()
def get_momentums(self, params): def get_momentums(self, params):
momentums = [] momentums = []
first_run = True first_run = True
......
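The new `set_grad_none` flag changes what `zero_grad()` does for FusedSGD. Below is a hedged usage sketch, assuming apex is built with CUDA extensions; the toy model and hyper-parameters are illustrative, while the constructor arguments come from the signature shown in the diff.
```
import torch
from apex.optimizers import FusedSGD  # requires apex built with CUDA extensions

model = torch.nn.Linear(16, 4).cuda()
opt = FusedSGD(model.parameters(), lr=0.1, momentum=0.9, set_grad_none=True)

model(torch.randn(8, 16, device="cuda")).sum().backward()

# With set_grad_none=True, zero_grad() releases the gradient tensors
# (p.grad becomes None) instead of filling them with zeros.
opt.zero_grad()
assert all(p.grad is None for p in model.parameters())
```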