Commit dcc7b513 authored by Jeff Daily

Merge remote-tracking branch 'upstream/master'

Conflicts:
csrc/multi_tensor_apply.cuh
setup.py
tests/L0/run_optimizers/test_adagrad.py
tests/L0/run_optimizers/test_fused_optimizer.py
tests/L0/run_optimizers/test_lamb.py
parents d061bf20 154c6336
@@ -150,12 +150,12 @@ CUDA and C++ extensions via
 ```
 $ git clone https://github.com/NVIDIA/apex
 $ cd apex
-$ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
+$ pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
 ```
 Apex also supports a Python-only build (required with Pytorch 0.4) via
 ```
-$ pip install -v --no-cache-dir ./
+$ pip install -v --disable-pip-version-check --no-cache-dir ./
 ```
 A Python-only build omits:
 - Fused kernels required to use `apex.optimizers.FusedAdam`.
...
@@ -113,7 +113,7 @@ torch::Tensor bwd_cuda(
 // Apply Dropout Mask and Scale by Dropout Probability
 // Softmax Grad
-dispatch_masked_scale_softmax_backward<half, half, float,false>(
+dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
 static_cast<half*>(output_grads.data_ptr()),
 static_cast<half*>(output_grads.data_ptr()),
 reinterpret_cast<half const*>(softmax_results.data_ptr()),
@@ -121,7 +121,7 @@ torch::Tensor bwd_cuda(
 1.0/(1.0-dropout_prob),
 k_seq_len,
 k_seq_len,
-attn_batches*q_seq_len);
+attn_batches*q_seq_len, stream);
 //backward pass is completely in-place
 return output_grads;
 }
...
@@ -115,7 +115,7 @@ torch::Tensor bwd_cuda(
 // Apply Dropout Mask and Scale by Dropout Probability
 // Softmax Grad
 if (padding_mask == nullptr) {
-dispatch_masked_scale_softmax_backward<half, half, float,false>(
+dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
 static_cast<half*>(output_grads.data_ptr()),
 static_cast<half*>(output_grads.data_ptr()),
 reinterpret_cast<half const*>(softmax_results.data_ptr()),
@@ -123,9 +123,9 @@ torch::Tensor bwd_cuda(
 1.0/(1.0-dropout_prob),
 k_seq_len,
 k_seq_len,
-attn_batches*q_seq_len);
+attn_batches*q_seq_len, stream);
 } else{
-dispatch_masked_scale_softmax_backward_masked_out<half, half, float,false>(
+dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float,false>(
 static_cast<half*>(output_grads.data_ptr()),
 static_cast<half*>(output_grads.data_ptr()),
 reinterpret_cast<half const*>(softmax_results.data_ptr()),
@@ -135,7 +135,7 @@ torch::Tensor bwd_cuda(
 k_seq_len,
 k_seq_len,
 attn_batches*q_seq_len,
-heads);
+heads, stream);
 }
 //backward pass is completely in-place
...
#pragma once
//Philox CUDA.

class Philox {
public:
  __device__ inline Philox(unsigned long long seed,
                           unsigned long long subsequence,
                           unsigned long long offset) {
    key.x = (unsigned int)seed;
    key.y = (unsigned int)(seed >> 32);
    counter = make_uint4(0, 0, 0, 0);
    counter.z = (unsigned int)(subsequence);
    counter.w = (unsigned int)(subsequence >> 32);
    STATE = 0;
    incr_n(offset / 4);
  }
  __device__ inline uint4 operator()() {
    if(STATE == 0) {
      uint4 counter_ = counter;
      uint2 key_ = key;
      //7-round philox
      for(int i = 0; i < 6; i++) {
        counter_ = single_round(counter_, key_);
        key_.x += (kPhilox10A); key_.y += (kPhilox10B);
      }
      output = single_round(counter_, key_);
      incr();
    }
    //return a float4 directly
    //unsigned long ret;
    //switch(STATE) {
    //  case 0: ret = output.x; break;
    //  case 1: ret = output.y; break;
    //  case 2: ret = output.z; break;
    //  case 3: ret = output.w; break;
    //}
    //STATE = (STATE + 1) % 4;
    return output;
  }
private:
  uint4 counter;
  uint4 output;
  uint2 key;
  unsigned int STATE;
  __device__ inline void incr_n(unsigned long long n) {
    unsigned int nlo = (unsigned int)(n);
    unsigned int nhi = (unsigned int)(n >> 32);
    counter.x += nlo;
    if (counter.x < nlo)
      nhi++;
    counter.y += nhi;
    if (nhi <= counter.y)
      return;
    if (++counter.z)
      return;
    ++counter.w;
  }
  __device__ inline void incr() {
    if (++counter.x)
      return;
    if (++counter.y)
      return;
    if (++counter.z)
      return;
    ++counter.w;
  }
  __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
                                    unsigned int *result_high) {
    *result_high = __umulhi(a, b);
    return a*b;
  }
  __device__ inline uint4 single_round(uint4 ctr, uint2 key) {
    unsigned int hi0;
    unsigned int hi1;
    unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
    unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
    uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
    return ret;
  }
  static const unsigned long kPhilox10A = 0x9E3779B9;
  static const unsigned long kPhilox10B = 0xBB67AE85;
  static const unsigned long kPhiloxSA = 0xD2511F53;
  static const unsigned long kPhiloxSB = 0xCD9E8D57;
};
// Inverse of 2^32.
#define M_RAN_INVM32 2.3283064e-10f
__device__ __inline__ float4 uniform4(uint4 x) {
  return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32, x.w * M_RAN_INVM32);
}
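For readers unfamiliar with counter-based generators, here is a hedged, pure-Python restatement of the round structure above (illustrative only, not part of the source): one 4x32 counter block is mixed with the key over seven rounds, and the resulting 32-bit words are scaled by 2^-32 just as `uniform4` does.

```
# Illustrative re-implementation of the Philox generator above in plain Python.
MASK32 = 0xFFFFFFFF
PHILOX_10A, PHILOX_10B = 0x9E3779B9, 0xBB67AE85
PHILOX_SA,  PHILOX_SB  = 0xD2511F53, 0xCD9E8D57

def mulhilo32(a, b):
    prod = a * b
    return prod & MASK32, prod >> 32          # (low, high) 32-bit halves

def single_round(ctr, key):
    lo0, hi0 = mulhilo32(PHILOX_SA, ctr[0])
    lo1, hi1 = mulhilo32(PHILOX_SB, ctr[2])
    return [hi1 ^ ctr[1] ^ key[0], lo1, hi0 ^ ctr[3] ^ key[1], lo0]

def philox(seed, subsequence, offset):
    key = [seed & MASK32, (seed >> 32) & MASK32]
    ctr = [(offset // 4) & MASK32, ((offset // 4) >> 32) & MASK32,
           subsequence & MASK32, (subsequence >> 32) & MASK32]
    for _ in range(6):                        # six keyed rounds ...
        ctr = single_round(ctr, key)
        key = [(key[0] + PHILOX_10A) & MASK32, (key[1] + PHILOX_10B) & MASK32]
    return single_round(ctr, key)             # ... plus a final one, as in operator()

# Map the four 32-bit words into [0, 1), mirroring uniform4().
print([x * 2.3283064e-10 for x in philox(seed=1234, subsequence=0, offset=0)])
```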
@@ -24,7 +24,9 @@ std::vector<torch::Tensor> bwd_cuda(
 torch::Tensor const& output_grads,
 torch::Tensor const& matmul2_results,
 torch::Tensor const& dropout_results,
-torch::Tensor const& softmax_results,
+// torch::Tensor const& softmax_results,
+torch::Tensor const& bmm1_results,
+torch::Tensor const& pad_mask,
 torch::Tensor const& input_lin_results,
 torch::Tensor const& inputs,
 torch::Tensor const& input_weights,
@@ -60,6 +62,7 @@ std::vector<torch::Tensor> fwd(
 AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
 AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
 AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
+AT_ASSERTM(use_mask , "no mask is not supported");
 if (use_mask) {
 AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
@@ -85,7 +88,8 @@ std::vector<torch::Tensor> bwd(
 torch::Tensor const& output_grads,
 torch::Tensor const& matmul2_results,
 torch::Tensor const& dropout_results,
-torch::Tensor const& softmax_results,
+torch::Tensor const& bmm1_results,
+torch::Tensor const& pad_mask,
 torch::Tensor const& input_lin_results,
 torch::Tensor const& inputs,
 torch::Tensor const& input_weights,
@@ -97,7 +101,6 @@ std::vector<torch::Tensor> bwd(
 AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
 AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
 AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
-AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
 AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
 AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
 AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
@@ -107,7 +110,6 @@ std::vector<torch::Tensor> bwd(
 AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
 AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
 AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
-AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
 AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
 AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
 AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
@@ -119,7 +121,8 @@ std::vector<torch::Tensor> bwd(
 output_grads,
 matmul2_results,
 dropout_results,
-softmax_results,
+bmm1_results,
+pad_mask,
 input_lin_results,
 inputs,
 input_weights,
...
@@ -63,7 +63,7 @@ std::vector<torch::Tensor> fwd_cuda(
 auto mask_options = act_options.dtype(torch::kUInt8);
 torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
-torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
+torch::Tensor bmm1_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
 torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
 torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
 torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
@@ -75,7 +75,8 @@ std::vector<torch::Tensor> fwd_cuda(
 void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);
 // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
-void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
+void* bmm1_results_ptr = static_cast<void*>(bmm1_results.data_ptr());
+void* dropout_results_ptr = static_cast<void*>(dropout_results.data_ptr());
 char a_layout_t{'t'};
 char a_layout_n{'n'};
@@ -119,23 +120,29 @@ std::vector<torch::Tensor> fwd_cuda(
 lead_dim,
 batch_stride,
 beta_zero,
-static_cast<half*>(softmax_results_ptr),
+static_cast<half*>(bmm1_results_ptr),
 k_seq_len,
 k_seq_len*q_seq_len,
 attn_batches);
 // Padded Softmax
 bool softmax_success = false;
-if (pad_mask == nullptr) {
-softmax_success = dispatch_softmax<half, half, float>(
-reinterpret_cast<half*>(softmax_results_ptr),
-reinterpret_cast<const half*>(softmax_results_ptr),
+if (is_training) {
+softmax_success = dispatch_additive_masked_softmax_dropout<half, half, float>(
+reinterpret_cast<half*>(dropout_results_ptr),
+(is_training) ? reinterpret_cast<uint8_t*>(dropout_mask.data_ptr<uint8_t>()) : nullptr,
+reinterpret_cast<const half*>(bmm1_results_ptr),
+pad_mask,
+attn_batches*q_seq_len*q_seq_len,
 k_seq_len,
 k_seq_len,
-attn_batches*q_seq_len);
+attn_batches*q_seq_len,
+attn_batches*q_seq_len/sequences,
+1.0f-dropout_prob,
+stream);
 } else{
 softmax_success = dispatch_additive_masked_softmax<half, half, float>(
-reinterpret_cast<half*>(softmax_results_ptr),
-reinterpret_cast<const half*>(softmax_results_ptr),
+reinterpret_cast<half*>(dropout_results_ptr),//this is actually softmax results, but making it consistent for the next function
+reinterpret_cast<const half*>(bmm1_results_ptr),
 pad_mask,
 k_seq_len,
 k_seq_len,
@@ -143,14 +150,6 @@ std::vector<torch::Tensor> fwd_cuda(
 attn_batches*q_seq_len/sequences);
 }
-if (is_training) {
-//use at:: function so that C++ version generates the same random mask as python version
-auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);
-dropout_results = std::get<0>(dropout_tuple);
-dropout_mask = std::get<1>(dropout_tuple);
-}
 // Matmul2
 gemm_switch_fp32accum( state,
 a_layout_n,
@@ -162,7 +161,7 @@ std::vector<torch::Tensor> fwd_cuda(
 static_cast<const half*>(v_lin_results_ptr),
 lead_dim,
 batch_stride,
-(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
+static_cast<const half*>(dropout_results.data_ptr()),
 k_seq_len,
 k_seq_len*q_seq_len,
 beta_zero,
@@ -199,7 +198,7 @@ std::vector<torch::Tensor> fwd_cuda(
 return {
 input_lin_results,
-softmax_results,
+bmm1_results,
 dropout_results,
 dropout_mask,
 matmul2_results,
@@ -212,7 +211,8 @@ std::vector<torch::Tensor> bwd_cuda(
 torch::Tensor const& output_grads,
 torch::Tensor const& matmul2_results,
 torch::Tensor const& dropout_results,
-torch::Tensor const& softmax_results,
+torch::Tensor const& bmm1_results,
+torch::Tensor const& pad_mask,
 torch::Tensor const& input_lin_results,
 torch::Tensor const& inputs,
 torch::Tensor const& input_weights,
@@ -350,15 +350,18 @@ std::vector<torch::Tensor> bwd_cuda(
 // Apply Dropout Mask and Scale by Dropout Probability
 // Softmax Grad
-dispatch_masked_scale_softmax_backward<half, half, float,false>(
+dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(
 static_cast<half*>(matmul2_grads.data_ptr()),
-static_cast<half*>(matmul2_grads.data_ptr()),
-reinterpret_cast<half const*>(softmax_results.data_ptr()),
+static_cast<half* const>(matmul2_grads.data_ptr()),
+reinterpret_cast<half const*>(bmm1_results.data_ptr()),
+reinterpret_cast<half const*>(pad_mask.data_ptr()),
 static_cast<uint8_t const*>(dropout_mask.data_ptr()),
 1.0/(1.0-dropout_prob),
 k_seq_len,
 k_seq_len,
-attn_batches*q_seq_len);
+attn_batches*q_seq_len/sequences,
+attn_batches*q_seq_len,
+stream);
 // Matmul1 Dgrad1
 gemm_switch_fp32accum( state,
...
@@ -361,7 +361,7 @@ std::vector<torch::Tensor> bwd_cuda(
 // Apply Dropout Mask and Scale by Dropout Probability
 // Softmax Grad
-dispatch_masked_scale_softmax_backward<half, half, float,false>(
+dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
 static_cast<half*>(matmul2_grads.data_ptr()),
 static_cast<half*>(matmul2_grads.data_ptr()),
 reinterpret_cast<half const*>(softmax_results.data_ptr()),
@@ -369,7 +369,7 @@ std::vector<torch::Tensor> bwd_cuda(
 1.0/(1.0-dropout_prob),
 k_seq_len,
 k_seq_len,
-attn_batches*q_seq_len);
+attn_batches*q_seq_len, stream);
 // Matmul1 Dgrad1
 gemm_switch_fp32accum( state,
...
#include <torch/extension.h>

void multi_tensor_fused_adam_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor per_tensor_beta1,
  at::Tensor per_tensor_beta2,
  at::Tensor per_tensor_bias_correction,
  at::Tensor per_tensor_eps,
  at::Tensor per_tensor_weight_decay,
  float lr,
  float grad_scale,
  int step,
  int mode);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_fused_adam", &multi_tensor_fused_adam_cuda,
        "Multi tensor Adam optimized CUDA implementation.");
}
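For orientation, a hedged sketch of how the binding declared above might be invoked from Python. The importable module name is set by `TORCH_EXTENSION_NAME` in `setup.py` and is not shown in this diff, so `fused_adam_cuda` below is purely illustrative; the argument order and the `[p, m, v, g]` tensor-list layout follow the declaration and the kernel comment, and in practice the call would be wrapped by an optimizer class rather than made directly.

```
import torch
import fused_adam_cuda  # hypothetical module name; the real one comes from setup.py

# One parameter tensor with its Adam state; p/m/v in fp32, gradients may be fp16.
p = torch.randn(1024, device="cuda")
m = torch.zeros_like(p)
v = torch.zeros_like(p)
g = torch.randn(1024, device="cuda", dtype=torch.float16)

noop_flag = torch.zeros(1, dtype=torch.int, device="cuda")

# Per-tensor hyperparameters: one entry per parameter tensor in the lists below.
beta1 = torch.tensor([0.9], device="cuda")
beta2 = torch.tensor([0.999], device="cuda")
bias_correction = torch.tensor([1], dtype=torch.int, device="cuda")
eps = torch.tensor([1e-8], device="cuda")
weight_decay = torch.tensor([0.01], device="cuda")

fused_adam_cuda.multi_tensor_fused_adam(
    2048 * 32,                # chunk_size
    noop_flag,
    [[p], [m], [v], [g]],     # p, m, v, g (an optional 5th list holds fp16 param copies)
    beta1, beta2, bias_correction, eps, weight_decay,
    1e-3,                     # lr
    1.0,                      # grad_scale (gradients are divided by this)
    1,                        # step
    0)                        # mode 0: eps under the square root
```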
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <THC/THCGeneral.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>
#include <cmath>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

typedef enum{
  ADAM_MODE_0 =0, // eps under square root
  ADAM_MODE_1 =1  // eps outside square root
} adamMode_t;
template <int DEPTH, typename T, typename GRAD_T>
struct DistAdamFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<DEPTH>& tl,
    const float* per_tensor_beta1,
    const float* per_tensor_beta2,
    const int* per_tensor_bias_correction,
    const float* per_tensor_eps,
    const float* per_tensor_weight_decay,
    const float lr,
    const float grad_scale,
    const int step,
    adamMode_t mode)
  {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    float b1 = per_tensor_beta1[tensor_num];
    float b2 = per_tensor_beta2[tensor_num];
    float eps = per_tensor_eps[tensor_num];
    float decay = per_tensor_weight_decay[tensor_num];

    float beta1_correction = 1.0f, beta2_correction = 1.0f;
    if (per_tensor_bias_correction[tensor_num] == 1) {
      beta1_correction = 1 - std::pow(b1, step);
      beta2_correction = 1 - std::pow(b2, step);
    }

    T* p = (T *)tl.addresses[0][tensor_loc];
    p += chunk_idx*chunk_size;
    T* m = (T *)tl.addresses[1][tensor_loc];
    m += chunk_idx*chunk_size;
    T* v = (T *)tl.addresses[2][tensor_loc];
    v += chunk_idx*chunk_size;
    GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
    g += chunk_idx*chunk_size;
    GRAD_T* p_copy = NULL;
    if (DEPTH == 5) {
      p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
      p_copy += chunk_idx*chunk_size;
    }

    n -= chunk_idx*chunk_size;

    T incoming_p[ILP];
    T incoming_m[ILP];
    T incoming_v[ILP];
    T incoming_g[ILP];

    // to make things simple, we put aligned case in a different code path
    if (n % ILP == 0 &&
        chunk_size % ILP == 0 &&
        is_aligned(p) &&
        is_aligned(m) &&
        is_aligned(v) &&
        is_aligned(g) &&
        is_aligned(p_copy)) {
      for (int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) {
        // load
        GRAD_T tmp_g[ILP];
        load_store(incoming_p, p, 0, i_start);
        load_store(incoming_m, m, 0, i_start);
        load_store(incoming_v, v, 0, i_start);
        load_store(tmp_g, g, 0, i_start);
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          incoming_g[ii] = static_cast<T>(tmp_g[ii]);
          T scaled_grad = incoming_g[ii]/grad_scale;
          incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
          incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
          T next_m_unbiased = incoming_m[ii] / beta1_correction;
          T next_v_unbiased = incoming_v[ii] / beta2_correction;
          float denom;
          if (mode == ADAM_MODE_0)
            denom = sqrtf(next_v_unbiased + eps);
          else // Mode 1
            denom = sqrtf(next_v_unbiased) + eps;
          float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]);
          incoming_p[ii] = incoming_p[ii] - (lr * update);
          if (DEPTH == 5) tmp_g[ii] = static_cast<GRAD_T>(incoming_p[ii]);
        }
        load_store(p, incoming_p, i_start, 0);
        load_store(m, incoming_m, i_start, 0);
        load_store(v, incoming_v, i_start, 0);
        if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0);
      }
    } else {
      for (int i_start = 0;
           i_start < n && i_start < chunk_size;
           i_start += blockDim.x*ILP) {
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          incoming_p[ii] = 0;
          incoming_m[ii] = 0;
          incoming_v[ii] = 0;
          incoming_g[ii] = 0;
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if (i < n && i < chunk_size) {
            incoming_p[ii] = p[i];
            incoming_m[ii] = m[i];
            incoming_v[ii] = v[i];
            incoming_g[ii] = static_cast<T>(g[i]);
          }
        }
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          int j = i_start + threadIdx.x + ii*blockDim.x;
          if (j < n && j < chunk_size) {
            T scaled_grad = incoming_g[ii]/grad_scale;
            m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
            v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
            T next_m_unbiased = m[j] / beta1_correction;
            T next_v_unbiased = v[j] / beta2_correction;
            float denom;
            if (mode == ADAM_MODE_0)
              denom = sqrtf(next_v_unbiased + eps);
            else // Mode 1
              denom = sqrtf(next_v_unbiased) + eps;
            float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]);
            p[j] = incoming_p[ii] - (lr * update);
            if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j];
          }
        }
      }
    }
  }
};
void multi_tensor_fused_adam_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
  at::Tensor per_tensor_beta1,
  at::Tensor per_tensor_beta2,
  at::Tensor per_tensor_bias_correction,
  at::Tensor per_tensor_eps,
  at::Tensor per_tensor_weight_decay,
  float lr,
  float grad_scale,
  int step,
  int mode)
{
  using namespace at;

  size_t tl_sz = tensor_lists.size();
  AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");

  if (tl_sz == 5) {
    DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      multi_tensor_apply<5>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        DistAdamFunctor<5, accscalar_t, scalar_t_0>(),
        per_tensor_beta1.DATA_PTR<float>(),
        per_tensor_beta2.DATA_PTR<float>(),
        per_tensor_bias_correction.DATA_PTR<int>(),
        per_tensor_eps.DATA_PTR<float>(),
        per_tensor_weight_decay.DATA_PTR<float>(),
        lr,
        grad_scale,
        step,
        (adamMode_t) mode);
    );
  } else {
    DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      multi_tensor_apply<4>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        DistAdamFunctor<4, accscalar_t, scalar_t_0>(),
        per_tensor_beta1.DATA_PTR<float>(),
        per_tensor_beta2.DATA_PTR<float>(),
        per_tensor_bias_correction.DATA_PTR<int>(),
        per_tensor_eps.DATA_PTR<float>(),
        per_tensor_weight_decay.DATA_PTR<float>(),
        lr,
        grad_scale,
        step,
        (adamMode_t) mode);
    );
  }
  THCudaCheck(cudaGetLastError());
}
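As a plain-math restatement of the two `adamMode_t` variants handled above (with $\hat m$, $\hat v$ the bias-corrected moments, $\lambda$ the per-tensor weight decay, and $\eta$ the learning rate):

$$
\texttt{ADAM\_MODE\_0}: \; u = \frac{\hat m}{\sqrt{\hat v + \epsilon}} + \lambda\, p,
\qquad
\texttt{ADAM\_MODE\_1}: \; u = \frac{\hat m}{\sqrt{\hat v} + \epsilon} + \lambda\, p,
\qquad
p \leftarrow p - \eta\, u .
$$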
@@ -12,8 +12,7 @@ void multi_tensor_lamb_compute_update_term_cuda(
 at::Tensor per_tensor_epsilon,
 const int mode,
 at::Tensor per_tensor_decay,
-const float global_grad_norm,
-const float max_global_grad_norm);
+const float grad_scale);
 void multi_tensor_lamb_update_weights_cuda(
 int chunk_size,
...
@@ -120,8 +120,7 @@ struct DistOptLAMBStage1Functor
 const MATH_T* per_tensor_epsilon,
 adamMode_t mode,
 const MATH_T* per_tensor_decay,
-const MATH_T global_grad_norm,
-const MATH_T max_global_grad_norm)
+const float grad_scale)
 {
 // I'd like this kernel to propagate infs/nans.
 // if(*noop_gmem == 1)
@@ -132,15 +131,13 @@ struct DistOptLAMBStage1Functor
 int chunk_idx = tl.block_to_chunk[blockIdx.x];
 int n = tl.sizes[tensor_loc];
-MATH_T clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : (MATH_T) 1.0;
 MATH_T beta1 = per_tensor_beta1[tensor_num];
-MATH_T beta2 = per_tensor_beta1[tensor_num];
-MATH_T beta3 = per_tensor_beta1[tensor_num];
+MATH_T beta2 = per_tensor_beta2[tensor_num];
+MATH_T beta3 = 1 - beta1;
 MATH_T beta1_correction, beta2_correction;
 if (per_tensor_bias_correction[tensor_num] == 1) {
-beta1_correction = 1 - pow(beta1, (MATH_T) step);
-beta2_correction = 1 - pow(beta2, (MATH_T) step);
+beta1_correction = 1 - pow(beta1, step);
+beta2_correction = 1 - pow(beta2, step);
 } else {
 beta1_correction = (MATH_T) 1.0;
 beta2_correction = (MATH_T) 1.0;
@@ -207,7 +204,7 @@ struct DistOptLAMBStage1Functor
 for(int ii = 0; ii < ILP; ii++)
 {
 if (mode == MOMENT_MODE_0) {
-MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
+MATH_T scaled_grad = r_g[ii] / grad_scale;
 // L2 on scaled grad
 scaled_grad = scaled_grad + decay*r_p[ii];
 r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -218,7 +215,7 @@ struct DistOptLAMBStage1Functor
 r_p[ii] = next_m_unbiased / denom;
 }
 else {
-MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
+MATH_T scaled_grad = r_g[ii] / grad_scale;
 r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
 r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
 MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
@@ -277,7 +274,7 @@ struct DistOptLAMBStage1Functor
 for(int ii = 0; ii < ILP; ii++)
 {
 if (mode == MOMENT_MODE_0) {
-MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
+MATH_T scaled_grad = r_g[ii] / grad_scale;
 // L2 on scaled grad
 scaled_grad = scaled_grad + decay*r_p[ii];
 r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -288,7 +285,7 @@ struct DistOptLAMBStage1Functor
 r_p[ii] = next_m_unbiased / denom;
 }
 else {
-MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
+MATH_T scaled_grad = r_g[ii] / grad_scale;
 r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
 r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
 MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
@@ -346,7 +343,7 @@ struct DistOptLAMBStage2Functor
 {
 MATH_T param_norm = per_tensor_param_norm[tensor_num];
 MATH_T update_norm = per_tensor_update_norm[tensor_num];
-ratio = (update_norm != (MATH_T) 0.0 && param_norm != (MATH_T) 0.0) ? learning_rate * (param_norm / update_norm) : learning_rate;
+ratio = (update_norm != 0.0 && param_norm != 0.0) ? learning_rate * (param_norm / update_norm) : learning_rate;
 }
 MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc];
@@ -434,8 +431,7 @@ void multi_tensor_lamb_compute_update_term_cuda(
 at::Tensor per_tensor_epsilon,
 const int mode,
 at::Tensor per_tensor_decay,
-const float global_grad_norm,
-const float max_global_grad_norm)
+const float grad_scale)
 {
 using namespace at;
@@ -456,8 +452,7 @@ void multi_tensor_lamb_compute_update_term_cuda(
 per_tensor_epsilon.DATA_PTR<scalar_t_2>(),
 (adamMode_t) mode,
 per_tensor_decay.DATA_PTR<scalar_t_2>(),
-(scalar_t_2) global_grad_norm,
-(scalar_t_2) max_global_grad_norm); )))
+grad_scale); )))
 AT_CUDA_CHECK(cudaGetLastError());
 }
...
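The `DistOptLAMBStage2Functor` hunk above computes the LAMB trust ratio from the per-tensor parameter and update norms; in math form (the weight update that applies this ratio lies outside the lines shown):

$$
\text{ratio} =
\begin{cases}
\eta \,\dfrac{\lVert p \rVert}{\lVert u \rVert} & \text{if } \lVert p \rVert \neq 0 \text{ and } \lVert u \rVert \neq 0,\\[4pt]
\eta & \text{otherwise.}
\end{cases}
$$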
@@ -11,6 +11,7 @@ class FastSelfAttnFunc(torch.autograd.Function) :
 dropout_prob_t = torch.tensor([dropout_prob])
 null_tensor = torch.tensor([])
 use_mask = (pad_mask is not None)
+mask_additive_t= torch.tensor([mask_additive])
 if use_biases_t[0]:
 if not mask_additive:
@@ -32,9 +33,24 @@ class FastSelfAttnFunc(torch.autograd.Function) :
 output_biases, \
 pad_mask if use_mask else null_tensor, \
 dropout_prob)
+ctx.save_for_backward(use_biases_t, \
+heads_t, \
+matmul2_results, \
+dropout_results, \
+softmax_results, \
+null_tensor, \
+null_tensor, \
+mask_additive_t, \
+input_lin_results, \
+inputs, \
+input_weights, \
+output_weights, \
+dropout_mask, \
+dropout_prob_t)
 else:
 input_lin_results, \
-softmax_results, \
+bmm1_results, \
 dropout_results, \
 dropout_mask, \
 matmul2_results, \
@@ -51,6 +67,20 @@ class FastSelfAttnFunc(torch.autograd.Function) :
 output_biases, \
 pad_mask if use_mask else null_tensor, \
 dropout_prob)
+ctx.save_for_backward(use_biases_t, \
+heads_t, \
+matmul2_results, \
+dropout_results, \
+null_tensor, \
+bmm1_results, \
+pad_mask, \
+mask_additive_t, \
+input_lin_results, \
+inputs, \
+input_weights, \
+output_weights, \
+dropout_mask, \
+dropout_prob_t)
 else:
@@ -70,20 +100,20 @@ class FastSelfAttnFunc(torch.autograd.Function) :
 output_weights, \
 pad_mask if use_mask else null_tensor, \
 dropout_prob)
 ctx.save_for_backward(use_biases_t, \
 heads_t, \
 matmul2_results, \
 dropout_results, \
 softmax_results, \
+null_tensor, \
+null_tensor, \
+mask_additive_t, \
 input_lin_results, \
 inputs, \
 input_weights, \
 output_weights, \
 dropout_mask, \
 dropout_prob_t)
 return outputs.detach()
 @staticmethod
@@ -93,6 +123,9 @@ class FastSelfAttnFunc(torch.autograd.Function) :
 matmul2_results, \
 dropout_results, \
 softmax_results, \
+bmm1_results, \
+pad_mask, \
+mask_additive_t, \
 input_lin_results, \
 inputs, \
 input_weights, \
@@ -101,6 +134,7 @@ class FastSelfAttnFunc(torch.autograd.Function) :
 dropout_prob_t = ctx.saved_tensors
 if use_biases_t[0]:
+if not mask_additive_t[0]:
 input_grads, \
 input_weight_grads, \
 output_weight_grads, \
@@ -119,6 +153,26 @@ class FastSelfAttnFunc(torch.autograd.Function) :
 dropout_mask, \
 dropout_prob_t[0])
+else:
+input_grads, \
+input_weight_grads, \
+output_weight_grads, \
+input_bias_grads, \
+output_bias_grads = \
+fast_self_multihead_attn_bias_additive_mask.backward( \
+heads_t[0], \
+output_grads, \
+matmul2_results, \
+dropout_results, \
+bmm1_results, \
+pad_mask, \
+input_lin_results, \
+inputs, \
+input_weights, \
+output_weights, \
+dropout_mask, \
+dropout_prob_t[0])
 else:
 input_bias_grads = None
 output_bias_grads = None
...
@@ -6,7 +6,7 @@ class SelfAttnFunc(torch.autograd.Function):
 def forward(ctx, use_time_mask, is_training, heads, scale, inputs,
 input_weights, output_weights,
 input_biases, output_biases,
-mask, dropout_prob):
+mask, is_additive_mask, dropout_prob):
 use_biases_t = torch.tensor([input_biases is not None])
 heads_t = torch.tensor([heads])
 scale_t = torch.tensor([scale])
@@ -60,6 +60,9 @@ class SelfAttnFunc(torch.autograd.Function):
 batches,seql_q,seql_k = matmul1_results.size()
 seqs = int(batches / heads)
 matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
+if is_additive_mask:
+matmul1_results = matmul1_results + mask.unsqueeze(1).unsqueeze(2)
+else:
 mask = mask.to(torch.bool)
 matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
 matmul1_results = matmul1_results.view(seqs*heads, seql_q, seql_k)
...
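The change above gives `SelfAttnFunc` two key-padding-mask conventions: a boolean mask that fills masked positions with `-inf`, or an additive mask that already holds large negative values and is simply added to the attention scores. A small standalone sketch of the two (tensor shapes and the `-10000.0` constant are illustrative, mirroring the unit test later in this commit):

```
import torch

# Scores are (seqs*heads, seql_q, seql_k); the key-padding mask is (seqs, seql_k).
seqs, heads, seql_q, seql_k = 2, 4, 5, 5
scores = torch.randn(seqs * heads, seql_q, seql_k)
pad = torch.tensor([[0, 0, 0, 1, 1],
                    [0, 0, 1, 1, 1]], dtype=torch.bool)   # True = padded key

scores = scores.view(seqs, heads, seql_q, seql_k)

# Boolean convention: fill masked positions with -inf before the softmax.
bool_masked = scores.masked_fill(pad.unsqueeze(1).unsqueeze(2), float('-inf'))

# Additive convention: the mask already contains large negative values and is added.
additive = pad.float() * -10000.0
add_masked = scores + additive.unsqueeze(1).unsqueeze(2)

# Both land in (nearly) the same place after the softmax.
print(torch.allclose(bool_masked.softmax(-1), add_masked.softmax(-1), atol=1e-3))
```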
@@ -56,6 +56,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
 (default: 1.0)
 use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
 weight decay parameter (default: False)
+clip_grad_norm (boolean, optional): whether to handle gradient clipping
+(default: True)
 .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
 https://arxiv.org/abs/1904.00962
@@ -67,7 +69,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
 lr=1e-3, bias_correction = True, grad_averaging=True,
 betas=(0.9, 0.999), eps=1e-8,
 weight_decay=0., max_grad_norm=0.,
-adam_w_mode=True, use_nvlamb=False,
+adam_w_mode=True, use_nvlamb=False, clip_grad_norm=True,
 amp_scale_adjustment=1.0, overlap_reductions=True,
 dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
@@ -89,6 +91,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
 'max_grad_norm': max_grad_norm,
 'adam_w_mode': adam_w_mode,
 'use_nvlamb': use_nvlamb,
+'clip_grad_norm': clip_grad_norm,
 'amp_scale_adjustment': amp_scale_adjustment,
 'overlap_reductions': overlap_reductions,
 'dwu_group_size': dwu_group_size,
@@ -107,7 +110,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
 lr=1e-3, bias_correction = True, grad_averaging=True,
 betas=(0.9, 0.999), eps=1e-8,
 weight_decay=0., max_grad_norm=0.,
-adam_w_mode=True, use_nvlamb=False,
+adam_w_mode=True, use_nvlamb=False, clip_grad_norm=True,
 amp_scale_adjustment=1.0, overlap_reductions=True,
 dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
@@ -127,9 +130,11 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
 self._adam_w_mode = 1 if adam_w_mode else 0
 self._use_nvlamb = use_nvlamb
+self._clip_grad_norm = clip_grad_norm
 self._is_accumulation_step = False
 self._last_step = False
 self._overlap_reductions = overlap_reductions
+self._global_scale = None
 self._num_blocks = dwu_num_blocks
 self._num_chunks = dwu_num_chunks
 self._e5m2_allgather = e5m2_allgather
@@ -468,9 +473,23 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
 local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True)[1] ** 2
 l2_norm.masked_scatter_(self._model_param_is_contrib, local_contrib_l2_norm)
 torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0])
+l2_norm = torch.sqrt(l2_norm)
 return l2_norm.masked_select(self._model_param_is_contrib)
 def _pipeline_step(self):
+# If self._clip_grad_norm is False, we assume gradient clipping already
+# happened outside the optimizer and self._global_scale has already
+# been set to the combined scale, i.e. it's no longer the current loss
+# scale used by the loss scaler.
+# For model parallelism cases in which we need to get global gradient
+# norm via all-reduce outside the optimizer to do the clipping.
+combined_scale = self.global_scale
+max_grad_norm = self.defaults['max_grad_norm']
+global_grad_norm = self.L2_grad_norm
+if self._clip_grad_norm and max_grad_norm > 0 and math.isfinite(global_grad_norm):
+combined_scale = max_grad_norm / (global_grad_norm / self.global_scale + 1e-6)
+combined_scale = self.global_scale / min(1, combined_scale)
 # Call step kernel once per step
 # Call all-gather once per step
 with torch.cuda.stream(self._completion_st):
@@ -478,7 +497,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
 for chunk_id in range(self._num_chunks):
 self._reductions_works[block_id][chunk_id].wait()
 param_norm = self.__compute_contrib_param_norm()
-max_grad_norm = self.defaults['max_grad_norm']
 multi_tensor_applier(self.multi_tensor_lamb_compute_update_term,
 self._overflow_buf,
 self._contrib_compute_update_term_tensor_list, # g, p, m, v, u
@@ -490,8 +508,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
 self._contrib_epsilon,
 self._adam_w_mode,
 self._contrib_weight_decay,
-self.L2_grad_norm,
-max_grad_norm)
+combined_scale)
 upd_norm = self.__compute_contrib_update_norm()
 multi_tensor_applier(self.multi_tensor_lamb_update_weights,
 self._overflow_buf,
@@ -537,6 +554,15 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
 self._pipeline_block_reductions(block_id)
 flush_block = self._get_flush_block()
+def set_global_scale(self, global_scale):
+"""Set global scale.
+"""
+self._global_scale = global_scale
+@property
+def global_scale(self):
+return self._global_scale
 @property
 def L2_grad_norm(self):
 torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
...
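To make the new clipping path concrete, here is a standalone restatement of the `combined_scale` computation in `_pipeline_step` above; the function name and example numbers are illustrative only.

```
import math

def combined_scale(global_scale, global_grad_norm, max_grad_norm, clip_grad_norm=True):
    """Restates the scale computation from _pipeline_step above.

    global_scale      -- the loss scale set via set_global_scale()
    global_grad_norm  -- L2 norm of the (still scaled) gradients
    max_grad_norm     -- clipping threshold from the optimizer defaults
    """
    scale = global_scale
    if clip_grad_norm and max_grad_norm > 0 and math.isfinite(global_grad_norm):
        clip = max_grad_norm / (global_grad_norm / global_scale + 1e-6)
        scale = global_scale / min(1, clip)
    return scale

# Example: loss scale 1024, unscaled grad norm 5.0, clip at 1.0.
# The kernel divides gradients by this value, so they end up both unscaled
# and clipped to roughly unit norm.
print(combined_scale(1024.0, 1024.0 * 5.0, 1.0))   # approximately 5120
```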
 # Introduction to ASP
-This page documents the API for ASP (Automatic Sparsity), a tool that enables sparse training and inference for PyTorch models by adding 2 lines of Python.
+This serves as a quick-start for ASP (Automatic SParsity), a tool that enables sparse training and inference for PyTorch models by adding 2 lines of Python.
 ## Importing ASP
 ```
@@ -14,8 +14,7 @@ Apart from the import statement, it is sufficient to add just the following line
 ASP.prune_trained_model(model, optimizer)
 ```
-In a typical PyTorch training loop, it might look like this:
+In the context of a typical PyTorch training loop, it might look like this:
 ```
 ASP.prune_trained_model(model, optimizer)
@@ -28,10 +27,52 @@ for epoch in range(epochs):
 torch.save(...)
 ```
-The `prune_trained_model` calculates the sparse mask and applies it to the weights. This is done once, i.e., sparse locations in the weights matrix remain fixed after this step. In order to recompute the sparse mask in between training, say after an epoch, use the following method:
+The `prune_trained_model` step calculates the sparse mask and applies it to the weights. This is done once, i.e., sparse locations in the weights matrix remain fixed after this step.
+
+## Generate a Sparse Network
+The following approach serves as a guiding example on how to generate a pruned model that can use Sparse Tensor Cores in the NVIDIA Ampere Architecture. This approach generates a model for deployment, i.e. inference mode.
+
+```
+(1) Given a fully trained (dense) network, prune parameter values in a 2:4 sparse pattern.
+(2) Fine-tune the pruned model with the optimization method and hyper-parameters (learning rate, schedule, number of epochs, etc.) exactly as those used to obtain the trained model.
+(3) (If required) Quantize the model.
+```
+
+In code, below is a sketch on how to use ASP for this approach (steps 1 and 2 above).
+
+```
+model = define_model(..., pretrained=True) # define model architecture and load parameter tensors with trained values (by reading a trained checkpoint)
+criterion = ... # compare ground truth with model prediction; use the same criterion as used to generate the dense trained model
+optimizer = ... # optimize model parameters; use the same optimizer as used to generate the dense trained model
+lr_scheduler = ... # learning rate scheduler; use the same schedule as used to generate the dense trained model
+
+from apex.contrib.sparsity import ASP
+ASP.prune_trained_model(model, optimizer) # prune the trained model
+
+x, y = DataLoader(args)
+for epoch in range(epochs): # train the pruned model for the same number of epochs as used to generate the dense trained model
+    optimizer.zero_grad()
+    y_pred = model(x)
+    loss = criterion(y_pred, y)
+    loss.backward()
+    optimizer.step()
+    lr_scheduler.step()
+
+torch.save(...) # saves the pruned checkpoint with sparsity masks
+```
+
+## Non-Standard Usage
+If your goal is to easily prepare a network for accelerated inference, please follow the recipe above. However, ASP can also be used to perform experiments with advanced techniques, such as training with sparsity from initialization. For example, in order to recompute the sparse mask in between training steps, use the following method:
 ```
 ASP.compute_sparse_masks()
 ```
 A more thorough example can be found in `./test/toy_problem.py`.
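A minimal, hedged sketch of that non-standard flow (model, optimizer, criterion, dataloader, and epochs are assumed to exist; only the two ASP calls shown in this README are used):

```
from apex.contrib.sparsity import ASP

ASP.prune_trained_model(model, optimizer)   # compute and apply the initial masks

for epoch in range(epochs):
    for x, y in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
    ASP.compute_sparse_masks()              # refresh the masks after each epoch
```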
import torch
import unittest

from apex.contrib.multihead_attn import SelfMultiheadAttn

class SelfMultiheadAttnTest(unittest.TestCase):
    def setUp(self, seed=1234):
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        self.seq_length = 80
        self.sequences = 10
        self.hidden_dim = 1024
        self.heads = 16
        self.dropout_prob = 0.0

        self.ref_layer = SelfMultiheadAttn(self.hidden_dim,
                                           self.heads,
                                           dropout=self.dropout_prob,
                                           bias=True,
                                           include_norm_add=False,
                                           separate_qkv_params=True,
                                           mask_additive=True,
                                           impl='default')
        self.ref_layer.cuda().half()
        self.ref_layer.reset_parameters()

        self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
                                      dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)

        # Reset seed so parameters are identical
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        self.tst_layer = SelfMultiheadAttn(self.hidden_dim,
                                           self.heads,
                                           dropout=self.dropout_prob,
                                           bias=True,
                                           include_norm_add=False,
                                           separate_qkv_params=True,
                                           mask_additive=True,
                                           impl='fast')
        self.tst_layer.cuda().half()
        self.tst_layer.reset_parameters()

        self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
                                      dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)

    def test_self_multihead_attn_additive_mask(self):
        grads = torch.randn_like(self.tst_inputs)
        mask = ((torch.randn(self.sequences, self.seq_length) > 0) * -10000.0).half().cuda()

        ref_outputs, _ = self.ref_layer.forward(self.ref_inputs,
                                                self.ref_inputs,
                                                self.ref_inputs,
                                                key_padding_mask=mask,
                                                need_weights=False,
                                                attn_mask=None,
                                                is_training=True)
        tst_outputs, _ = self.tst_layer.forward(self.tst_inputs,
                                                self.tst_inputs,
                                                self.tst_inputs,
                                                key_padding_mask=mask,
                                                need_weights=False,
                                                attn_mask=None,
                                                is_training=True)

        self.ref_inputs.backward(grads)
        self.tst_inputs.backward(grads)

        self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
        self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))

if __name__ == '__main__':
    unittest.main()
@@ -76,7 +76,7 @@ class FusedLAMB(torch.optim.Optimizer):
 import amp_C
 self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
 # Skip buffer
-self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
 self.multi_tensor_lamb = amp_C.multi_tensor_lamb
 else:
 raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')
@@ -117,7 +117,8 @@ class FusedLAMB(torch.optim.Optimizer):
 else:
 raise RuntimeError('FusedLAMB only support fp16 and fp32.')
-g_norm_32, g_norm_16 = torch.zeros(1, device='cuda'), torch.zeros(1, device='cuda')
+device = self.param_groups[0]["params"][0].device
+g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device)
 # compute grad norm for two lists
 if len(g_all_32) > 0:
 g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
...
@@ -76,7 +76,8 @@ class FusedSGD(Optimizer):
 def __init__(self, params, lr=required, momentum=0, dampening=0,
 weight_decay=0, nesterov=False,
 wd_after_momentum=False,
-materialize_master_grads=True):
+materialize_master_grads=True,
+set_grad_none=False):
 if lr is not required and lr < 0.0:
 raise ValueError("Invalid learning rate: {}".format(lr))
 if momentum < 0.0:
@@ -94,11 +95,12 @@ class FusedSGD(Optimizer):
 self.materialize_master_grads = materialize_master_grads
 self.most_recent_scale = 1.0
 self.scale_set_by_backward = False
+self.set_grad_none = set_grad_none
 if multi_tensor_applier.available:
 import amp_C
 # Skip buffer
-self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
 self.multi_tensor_sgd = amp_C.multi_tensor_sgd
 else:
 raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')
@@ -108,6 +110,14 @@ class FusedSGD(Optimizer):
 for group in self.param_groups:
 group.setdefault('nesterov', False)
+def zero_grad(self):
+if self.set_grad_none:
+for group in self.param_groups:
+for p in group['params']:
+p.grad = None
+else:
+super(FusedSGD, self).zero_grad()
 def get_momentums(self, params):
 momentums = []
 first_run = True
...
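A brief, hedged usage sketch of the new flag (a CUDA `model` and input `x` are assumed to exist; only the constructor argument and the `zero_grad` override shown above are exercised):

```
from apex.optimizers import FusedSGD

opt = FusedSGD(model.parameters(), lr=0.1, momentum=0.9, set_grad_none=True)

loss = model(x).sum()
loss.backward()
opt.step()
opt.zero_grad()   # with set_grad_none=True, p.grad becomes None instead of being zeroed
```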