Unverified Commit 85b56d01 authored by Jeff Daily's avatar Jeff Daily Committed by GitHub
Browse files

Merge pull request #43 from ROCmSoftwarePlatform/IFU-2021-01-18

IFU-2021-01-18
parents d061bf20 13c8d152
......@@ -150,12 +150,12 @@ CUDA and C++ extensions via
```
$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
$ pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```
Apex also supports a Python-only build (required with Pytorch 0.4) via
```
$ pip install -v --no-cache-dir ./
$ pip install -v --disable-pip-version-check --no-cache-dir ./
```
A Python-only build omits:
- Fused kernels required to use `apex.optimizers.FusedAdam`.
......
......@@ -113,7 +113,7 @@ torch::Tensor bwd_cuda(
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward<half, half, float,false>(
dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
......@@ -121,7 +121,7 @@ torch::Tensor bwd_cuda(
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
attn_batches*q_seq_len, stream);
//backward pass is completely in-place
return output_grads;
}
......
......@@ -115,7 +115,7 @@ torch::Tensor bwd_cuda(
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
if (padding_mask == nullptr) {
dispatch_masked_scale_softmax_backward<half, half, float,false>(
dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
......@@ -123,9 +123,9 @@ torch::Tensor bwd_cuda(
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
attn_batches*q_seq_len, stream);
} else{
dispatch_masked_scale_softmax_backward_masked_out<half, half, float,false>(
dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
......@@ -135,7 +135,7 @@ torch::Tensor bwd_cuda(
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
heads);
heads, stream);
}
//backward pass is completely in-place
......
#pragma once
//Philox CUDA.
class Philox {
public:
__device__ inline Philox(unsigned long long seed,
unsigned long long subsequence,
unsigned long long offset) {
key.x = (unsigned int)seed;
key.y = (unsigned int)(seed >> 32);
counter = make_uint4(0, 0, 0, 0);
counter.z = (unsigned int)(subsequence);
counter.w = (unsigned int)(subsequence >> 32);
STATE = 0;
incr_n(offset / 4);
}
__device__ inline uint4 operator()() {
if(STATE == 0) {
uint4 counter_ = counter;
uint2 key_ = key;
//7-round philox
for(int i = 0; i < 6; i++) {
counter_ = single_round(counter_, key_);
key_.x += (kPhilox10A); key_.y += (kPhilox10B);
}
output = single_round(counter_, key_);
incr();
}
//return a float4 directly
//unsigned long ret;
//switch(STATE) {
// case 0: ret = output.x; break;
// case 1: ret = output.y; break;
// case 2: ret = output.z; break;
// case 3: ret = output.w; break;
//}
//STATE = (STATE + 1) % 4;
return output;
}
private:
uint4 counter;
uint4 output;
uint2 key;
unsigned int STATE;
__device__ inline void incr_n(unsigned long long n) {
unsigned int nlo = (unsigned int)(n);
unsigned int nhi = (unsigned int)(n >> 32);
counter.x += nlo;
if (counter.x < nlo)
nhi++;
counter.y += nhi;
if (nhi <= counter.y)
return;
if (++counter.z)
return;
++counter.w;
}
__device__ inline void incr() {
if (++counter.x)
return;
if (++counter.y)
return;
if (++counter.z)
return;
++counter.w;
}
__device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
unsigned int *result_high) {
*result_high = __umulhi(a, b);
return a*b;
}
__device__ inline uint4 single_round(uint4 ctr, uint2 key) {
unsigned int hi0;
unsigned int hi1;
unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
return ret;
}
static const unsigned long kPhilox10A = 0x9E3779B9;
static const unsigned long kPhilox10B = 0xBB67AE85;
static const unsigned long kPhiloxSA = 0xD2511F53;
static const unsigned long kPhiloxSB = 0xCD9E8D57;
};
// Inverse of 2^32.
#define M_RAN_INVM32 2.3283064e-10f
__device__ __inline__ float4 uniform4(uint4 x) {
return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,x.w * M_RAN_INVM32);
}
......@@ -24,7 +24,9 @@ std::vector<torch::Tensor> bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
// torch::Tensor const& softmax_results,
torch::Tensor const& bmm1_results,
torch::Tensor const& pad_mask,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
......@@ -60,6 +62,7 @@ std::vector<torch::Tensor> fwd(
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(use_mask , "no mask is not supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
......@@ -85,7 +88,8 @@ std::vector<torch::Tensor> bwd(
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& bmm1_results,
torch::Tensor const& pad_mask,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
......@@ -97,7 +101,6 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
......@@ -107,7 +110,6 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
......@@ -119,7 +121,8 @@ std::vector<torch::Tensor> bwd(
output_grads,
matmul2_results,
dropout_results,
softmax_results,
bmm1_results,
pad_mask,
input_lin_results,
inputs,
input_weights,
......
......@@ -63,7 +63,7 @@ std::vector<torch::Tensor> fwd_cuda(
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor bmm1_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
......@@ -75,7 +75,8 @@ std::vector<torch::Tensor> fwd_cuda(
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
void* bmm1_results_ptr = static_cast<void*>(bmm1_results.data_ptr());
void* dropout_results_ptr = static_cast<void*>(dropout_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
......@@ -119,23 +120,29 @@ std::vector<torch::Tensor> fwd_cuda(
lead_dim,
batch_stride,
beta_zero,
static_cast<half*>(softmax_results_ptr),
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
if (is_training) {
softmax_success = dispatch_additive_masked_softmax_dropout<half, half, float>(
reinterpret_cast<half*>(dropout_results_ptr),
(is_training) ? reinterpret_cast<uint8_t*>(dropout_mask.data_ptr<uint8_t>()) : nullptr,
reinterpret_cast<const half*>(bmm1_results_ptr),
pad_mask,
attn_batches*q_seq_len*q_seq_len,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences,
1.0f-dropout_prob,
stream);
} else {
softmax_success = dispatch_additive_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
reinterpret_cast<half*>(dropout_results_ptr),//this is actually softmax results, but making it consistent for the next function
reinterpret_cast<const half*>(bmm1_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
......@@ -143,14 +150,6 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches*q_seq_len/sequences);
}
if (is_training) {
//use at:: function so that C++ version generates the same random mask as python version
auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
// Matmul2
gemm_switch_fp32accum( state,
a_layout_n,
......@@ -162,7 +161,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta_zero,
......@@ -199,7 +198,7 @@ std::vector<torch::Tensor> fwd_cuda(
return {
input_lin_results,
softmax_results,
bmm1_results,
dropout_results,
dropout_mask,
matmul2_results,
......@@ -212,7 +211,8 @@ std::vector<torch::Tensor> bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& bmm1_results,
torch::Tensor const& pad_mask,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
......@@ -350,15 +350,18 @@ std::vector<torch::Tensor> bwd_cuda(
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward<half, half, float,false>(
dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
static_cast<half* const>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(bmm1_results.data_ptr()),
reinterpret_cast<half const*>(pad_mask.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
attn_batches*q_seq_len/sequences,
attn_batches*q_seq_len,
stream);
// Matmul1 Dgrad1
gemm_switch_fp32accum( state,
......
......@@ -361,7 +361,7 @@ std::vector<torch::Tensor> bwd_cuda(
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward<half, half, float,false>(
dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
......@@ -369,7 +369,7 @@ std::vector<torch::Tensor> bwd_cuda(
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
attn_batches*q_seq_len, stream);
// Matmul1 Dgrad1
gemm_switch_fp32accum( state,
......
This diff is collapsed.
#include <torch/extension.h>
void multi_tensor_fused_adam_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_beta1,
at::Tensor per_tensor_beta2,
at::Tensor per_tensor_bias_correction,
at::Tensor per_tensor_eps,
at::Tensor per_tensor_weight_decay,
float lr,
float grad_scale,
int step,
int mode);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_fused_adam", &multi_tensor_fused_adam_cuda,
"Multi tensor Adam optimized CUDA implementation.");
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <THC/THCGeneral.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include <cmath>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
typedef enum{
ADAM_MODE_0 =0, // eps under square root
ADAM_MODE_1 =1 // eps outside square root
} adamMode_t;
template <int DEPTH, typename T, typename GRAD_T>
struct DistAdamFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<DEPTH>& tl,
const float* per_tensor_beta1,
const float* per_tensor_beta2,
const int* per_tensor_bias_correction,
const float* per_tensor_eps,
const float* per_tensor_weight_decay,
const float lr,
const float grad_scale,
const int step,
adamMode_t mode)
{
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float b1 = per_tensor_beta1[tensor_num];
float b2 = per_tensor_beta2[tensor_num];
float eps = per_tensor_eps[tensor_num];
float decay = per_tensor_weight_decay[tensor_num];
float beta1_correction = 1.0f, beta2_correction = 1.0f;
if (per_tensor_bias_correction[tensor_num] == 1) {
beta1_correction = 1 - std::pow(b1, step);
beta2_correction = 1 - std::pow(b2, step);
}
T* p = (T *)tl.addresses[0][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T *)tl.addresses[1][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T *)tl.addresses[2][tensor_loc];
v += chunk_idx*chunk_size;
GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
g += chunk_idx*chunk_size;
GRAD_T* p_copy = NULL;
if (DEPTH == 5) {
p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
p_copy += chunk_idx*chunk_size;
}
n -= chunk_idx*chunk_size;
T incoming_p[ILP];
T incoming_m[ILP];
T incoming_v[ILP];
T incoming_g[ILP];
// to make things simple, we put aligned case in a different code path
if (n % ILP == 0 &&
chunk_size % ILP == 0 &&
is_aligned(p) &&
is_aligned(m) &&
is_aligned(v) &&
is_aligned(g) &&
is_aligned(p_copy)) {
for (int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) {
// load
GRAD_T tmp_g[ILP];
load_store(incoming_p, p, 0, i_start);
load_store(incoming_m, m, 0, i_start);
load_store(incoming_v, v, 0, i_start);
load_store(tmp_g, g, 0, i_start);
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
incoming_g[ii] = static_cast<T>(tmp_g[ii]);
T scaled_grad = incoming_g[ii]/grad_scale;
incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
T next_m_unbiased = incoming_m[ii] / beta1_correction;
T next_v_unbiased = incoming_v[ii] / beta2_correction;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(next_v_unbiased + eps);
else // Mode 1
denom = sqrtf(next_v_unbiased) + eps;
float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]);
incoming_p[ii] = incoming_p[ii] - (lr * update);
if (DEPTH == 5) tmp_g[ii] = static_cast<GRAD_T>(incoming_p[ii]);
}
load_store(p, incoming_p, i_start, 0);
load_store(m, incoming_m, i_start, 0);
load_store(v, incoming_v, i_start, 0);
if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0);
}
} else {
for (int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
incoming_p[ii] = 0;
incoming_m[ii] = 0;
incoming_v[ii] = 0;
incoming_g[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x;
if (i < n && i < chunk_size) {
incoming_p[ii] = p[i];
incoming_m[ii] = m[i];
incoming_v[ii] = v[i];
incoming_g[ii] = static_cast<T>(g[i]);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int j = i_start + threadIdx.x + ii*blockDim.x;
if (j < n && j < chunk_size) {
T scaled_grad = incoming_g[ii]/grad_scale;
m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
T next_m_unbiased = m[j] / beta1_correction;
T next_v_unbiased = v[j] / beta2_correction;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(next_v_unbiased + eps);
else // Mode 1
denom = sqrtf(next_v_unbiased) + eps;
float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]);
p[j] = incoming_p[ii] - (lr * update);
if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j];
}
}
}
}
}
};
void multi_tensor_fused_adam_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
at::Tensor per_tensor_beta1,
at::Tensor per_tensor_beta2,
at::Tensor per_tensor_bias_correction,
at::Tensor per_tensor_eps,
at::Tensor per_tensor_weight_decay,
float lr,
float grad_scale,
int step,
int mode)
{
using namespace at;
size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");
if (tl_sz == 5) {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
DistAdamFunctor<5, accscalar_t, scalar_t_0>(),
per_tensor_beta1.DATA_PTR<float>(),
per_tensor_beta2.DATA_PTR<float>(),
per_tensor_bias_correction.DATA_PTR<int>(),
per_tensor_eps.DATA_PTR<float>(),
per_tensor_weight_decay.DATA_PTR<float>(),
lr,
grad_scale,
step,
(adamMode_t) mode);
);
} else {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
DistAdamFunctor<4, accscalar_t, scalar_t_0>(),
per_tensor_beta1.DATA_PTR<float>(),
per_tensor_beta2.DATA_PTR<float>(),
per_tensor_bias_correction.DATA_PTR<int>(),
per_tensor_eps.DATA_PTR<float>(),
per_tensor_weight_decay.DATA_PTR<float>(),
lr,
grad_scale,
step,
(adamMode_t) mode);
);
}
THCudaCheck(cudaGetLastError());
}
......@@ -12,8 +12,7 @@ void multi_tensor_lamb_compute_update_term_cuda(
at::Tensor per_tensor_epsilon,
const int mode,
at::Tensor per_tensor_decay,
const float global_grad_norm,
const float max_global_grad_norm);
const float grad_scale);
void multi_tensor_lamb_update_weights_cuda(
int chunk_size,
......
......@@ -120,8 +120,7 @@ struct DistOptLAMBStage1Functor
const MATH_T* per_tensor_epsilon,
adamMode_t mode,
const MATH_T* per_tensor_decay,
const MATH_T global_grad_norm,
const MATH_T max_global_grad_norm)
const float grad_scale)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
......@@ -132,15 +131,13 @@ struct DistOptLAMBStage1Functor
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
MATH_T clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : (MATH_T) 1.0;
MATH_T beta1 = per_tensor_beta1[tensor_num];
MATH_T beta2 = per_tensor_beta1[tensor_num];
MATH_T beta3 = per_tensor_beta1[tensor_num];
MATH_T beta2 = per_tensor_beta2[tensor_num];
MATH_T beta3 = 1 - beta1;
MATH_T beta1_correction, beta2_correction;
if (per_tensor_bias_correction[tensor_num] == 1) {
beta1_correction = 1 - pow(beta1, (MATH_T) step);
beta2_correction = 1 - pow(beta2, (MATH_T) step);
beta1_correction = 1 - pow(beta1, step);
beta2_correction = 1 - pow(beta2, step);
} else {
beta1_correction = (MATH_T) 1.0;
beta2_correction = (MATH_T) 1.0;
......@@ -207,7 +204,7 @@ struct DistOptLAMBStage1Functor
for(int ii = 0; ii < ILP; ii++)
{
if (mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
MATH_T scaled_grad = r_g[ii] / grad_scale;
// L2 on scaled grad
scaled_grad = scaled_grad + decay*r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
......@@ -218,7 +215,7 @@ struct DistOptLAMBStage1Functor
r_p[ii] = next_m_unbiased / denom;
}
else {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
MATH_T scaled_grad = r_g[ii] / grad_scale;
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
......@@ -277,7 +274,7 @@ struct DistOptLAMBStage1Functor
for(int ii = 0; ii < ILP; ii++)
{
if (mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
MATH_T scaled_grad = r_g[ii] / grad_scale;
// L2 on scaled grad
scaled_grad = scaled_grad + decay*r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
......@@ -288,7 +285,7 @@ struct DistOptLAMBStage1Functor
r_p[ii] = next_m_unbiased / denom;
}
else {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
MATH_T scaled_grad = r_g[ii] / grad_scale;
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
......@@ -346,7 +343,7 @@ struct DistOptLAMBStage2Functor
{
MATH_T param_norm = per_tensor_param_norm[tensor_num];
MATH_T update_norm = per_tensor_update_norm[tensor_num];
ratio = (update_norm != (MATH_T) 0.0 && param_norm != (MATH_T) 0.0) ? learning_rate * (param_norm / update_norm) : learning_rate;
ratio = (update_norm != 0.0 && param_norm != 0.0) ? learning_rate * (param_norm / update_norm) : learning_rate;
}
MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc];
......@@ -434,8 +431,7 @@ void multi_tensor_lamb_compute_update_term_cuda(
at::Tensor per_tensor_epsilon,
const int mode,
at::Tensor per_tensor_decay,
const float global_grad_norm,
const float max_global_grad_norm)
const float grad_scale)
{
using namespace at;
......@@ -456,8 +452,7 @@ void multi_tensor_lamb_compute_update_term_cuda(
per_tensor_epsilon.DATA_PTR<scalar_t_2>(),
(adamMode_t) mode,
per_tensor_decay.DATA_PTR<scalar_t_2>(),
(scalar_t_2) global_grad_norm,
(scalar_t_2) max_global_grad_norm); )))
grad_scale); )))
AT_CUDA_CHECK(cudaGetLastError());
}
......
......@@ -11,6 +11,7 @@ class FastSelfAttnFunc(torch.autograd.Function) :
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
mask_additive_t= torch.tensor([mask_additive])
if use_biases_t[0]:
if not mask_additive:
......@@ -32,9 +33,24 @@ class FastSelfAttnFunc(torch.autograd.Function) :
output_biases, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(use_biases_t, \
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
null_tensor, \
null_tensor, \
mask_additive_t, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t)
else:
input_lin_results, \
softmax_results, \
bmm1_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
......@@ -51,6 +67,20 @@ class FastSelfAttnFunc(torch.autograd.Function) :
output_biases, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(use_biases_t, \
heads_t, \
matmul2_results, \
dropout_results, \
null_tensor, \
bmm1_results, \
pad_mask, \
mask_additive_t, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t)
else:
......@@ -70,20 +100,20 @@ class FastSelfAttnFunc(torch.autograd.Function) :
output_weights, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(use_biases_t, \
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t)
ctx.save_for_backward(use_biases_t, \
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
null_tensor, \
null_tensor, \
mask_additive_t, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t)
return outputs.detach()
@staticmethod
......@@ -93,6 +123,9 @@ class FastSelfAttnFunc(torch.autograd.Function) :
matmul2_results, \
dropout_results, \
softmax_results, \
bmm1_results, \
pad_mask, \
mask_additive_t, \
input_lin_results, \
inputs, \
input_weights, \
......@@ -101,24 +134,45 @@ class FastSelfAttnFunc(torch.autograd.Function) :
dropout_prob_t = ctx.saved_tensors
if use_biases_t[0]:
input_grads, \
input_weight_grads, \
output_weight_grads, \
input_bias_grads, \
output_bias_grads = \
fast_self_multihead_attn_bias.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
if not mask_additive_t[0]:
input_grads, \
input_weight_grads, \
output_weight_grads, \
input_bias_grads, \
output_bias_grads = \
fast_self_multihead_attn_bias.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
else:
input_grads, \
input_weight_grads, \
output_weight_grads, \
input_bias_grads, \
output_bias_grads = \
fast_self_multihead_attn_bias_additive_mask.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
bmm1_results, \
pad_mask, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
else:
input_bias_grads = None
output_bias_grads = None
......
......@@ -6,7 +6,7 @@ class SelfAttnFunc(torch.autograd.Function):
def forward(ctx, use_time_mask, is_training, heads, scale, inputs,
input_weights, output_weights,
input_biases, output_biases,
mask, dropout_prob):
mask, is_additive_mask, dropout_prob):
use_biases_t = torch.tensor([input_biases is not None])
heads_t = torch.tensor([heads])
scale_t = torch.tensor([scale])
......@@ -60,8 +60,11 @@ class SelfAttnFunc(torch.autograd.Function):
batches,seql_q,seql_k = matmul1_results.size()
seqs = int(batches / heads)
matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
if is_additive_mask:
matmul1_results = matmul1_results + mask.unsqueeze(1).unsqueeze(2)
else:
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
matmul1_results = matmul1_results.view(seqs*heads, seql_q, seql_k)
softmax_results = F.softmax(matmul1_results, dim=-1)
......
......@@ -56,6 +56,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
(default: 1.0)
use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
weight decay parameter (default: False)
clip_grad_norm (boolean, optional): whether to handle gradient clipping
(default: True)
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
......@@ -67,7 +69,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
lr=1e-3, bias_correction = True, grad_averaging=True,
betas=(0.9, 0.999), eps=1e-8,
weight_decay=0., max_grad_norm=0.,
adam_w_mode=True, use_nvlamb=False,
adam_w_mode=True, use_nvlamb=False, clip_grad_norm=True,
amp_scale_adjustment=1.0, overlap_reductions=True,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
......@@ -89,6 +91,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
'max_grad_norm': max_grad_norm,
'adam_w_mode': adam_w_mode,
'use_nvlamb': use_nvlamb,
'clip_grad_norm': clip_grad_norm,
'amp_scale_adjustment': amp_scale_adjustment,
'overlap_reductions': overlap_reductions,
'dwu_group_size': dwu_group_size,
......@@ -107,7 +110,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
lr=1e-3, bias_correction = True, grad_averaging=True,
betas=(0.9, 0.999), eps=1e-8,
weight_decay=0., max_grad_norm=0.,
adam_w_mode=True, use_nvlamb=False,
adam_w_mode=True, use_nvlamb=False, clip_grad_norm=True,
amp_scale_adjustment=1.0, overlap_reductions=True,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
......@@ -127,9 +130,11 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._adam_w_mode = 1 if adam_w_mode else 0
self._use_nvlamb = use_nvlamb
self._clip_grad_norm = clip_grad_norm
self._is_accumulation_step = False
self._last_step = False
self._overlap_reductions = overlap_reductions
self._global_scale = None
self._num_blocks = dwu_num_blocks
self._num_chunks = dwu_num_chunks
self._e5m2_allgather = e5m2_allgather
......@@ -468,9 +473,23 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True)[1] ** 2
l2_norm.masked_scatter_(self._model_param_is_contrib, local_contrib_l2_norm)
torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0])
l2_norm = torch.sqrt(l2_norm)
return l2_norm.masked_select(self._model_param_is_contrib)
def _pipeline_step(self):
# If self._clip_grad_norm is False, we assume gradient clipping already
# happened outside the optimizer and self._global_scale has already
# been set to the combined scale, i.e. it's no longer the current loss
# scale used by the loss scaler.
# For model parallelism cases in which we need to get global gradient
# norm via all-reduce outside the optimizer to do the clipping.
combined_scale = self.global_scale
max_grad_norm = self.defaults['max_grad_norm']
global_grad_norm = self.L2_grad_norm
if self._clip_grad_norm and max_grad_norm > 0 and math.isfinite(global_grad_norm):
combined_scale = max_grad_norm / (global_grad_norm / self.global_scale + 1e-6)
combined_scale = self.global_scale / min(1, combined_scale)
# Call step kernel once per step
# Call all-gather once per step
with torch.cuda.stream(self._completion_st):
......@@ -478,7 +497,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
param_norm = self.__compute_contrib_param_norm()
max_grad_norm = self.defaults['max_grad_norm']
multi_tensor_applier(self.multi_tensor_lamb_compute_update_term,
self._overflow_buf,
self._contrib_compute_update_term_tensor_list, # g, p, m, v, u
......@@ -490,8 +508,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._contrib_epsilon,
self._adam_w_mode,
self._contrib_weight_decay,
self.L2_grad_norm,
max_grad_norm)
combined_scale)
upd_norm = self.__compute_contrib_update_norm()
multi_tensor_applier(self.multi_tensor_lamb_update_weights,
self._overflow_buf,
......@@ -537,6 +554,15 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._pipeline_block_reductions(block_id)
flush_block = self._get_flush_block()
def set_global_scale(self, global_scale):
"""Set global scale.
"""
self._global_scale = global_scale
@property
def global_scale(self):
return self._global_scale
@property
def L2_grad_norm(self):
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
......
# Introduction to ASP
This page documents the API for ASP (Automatic Sparsity), a tool that enables sparse training and inference for PyTorch models by adding 2 lines of Python.
This serves as a quick-start for ASP (Automatic SParsity), a tool that enables sparse training and inference for PyTorch models by adding 2 lines of Python.
## Importing ASP
```
......@@ -14,8 +14,7 @@ Apart from the import statement, it is sufficient to add just the following line
ASP.prune_trained_model(model, optimizer)
```
In a typical PyTorch training loop, it might look like this:
In the context of a typical PyTorch training loop, it might look like this:
```
ASP.prune_trained_model(model, optimizer)
......@@ -28,10 +27,52 @@ for epoch in range(epochs):
torch.save(...)
```
The `prune_trained_model` calculates the sparse mask and applies it to the weights. This is done once, i.e., sparse locations in the weights matrix remain fixed after this step. In order to recompute the sparse mask in between training, say after an epoch, use the following method:
The `prune_trained_model` step calculates the sparse mask and applies it to the weights. This is done once, i.e., sparse locations in the weights matrix remain fixed after this step.
## Generate a Sparse Network
The following approach serves as a guiding example on how to generate a pruned model that can use Sparse Tensor Cores in the NVIDIA Ampere Architecture. This approach generates a model for deployment, i.e. inference mode.
```
(1) Given a fully trained (dense) network, prune parameter values in a 2:4 sparse pattern.
(2) Fine-tune the pruned model with optimization method and hyper-parameters (learning-rate, schedule, number of epochs, etc.) exactly as those used to obtain the trained model.
(3) (If required) Quantize the model.
```
In code, below is a sketch on how to use ASP for this approach (steps 1 and 2 above).
```
model = define_model(..., pretrained=True) # define model architecture and load parameter tensors with trained values (by reading a trained checkpoint)
criterion = ... # compare ground truth with model predition; use the same criterion as used to generate the dense trained model
optimizer = ... # optimize model parameters; use the same optimizer as used to generate the dense trained model
lr_scheduler = ... # learning rate scheduler; use the same schedule as used to generate the dense trained model
from apex.contrib.sparsity import ASP
ASP.prune_trained_model(model, optimizer) #pruned a trained model
x, y = DataLoader(args)
for epoch in range(epochs): # train the pruned model for the same number of epochs as used to generate the dense trained model
y_pred = model(x)
loss = criterion(y_pred, y)
lr_scheduler.step()
loss.backward()
optimizer.step()
torch.save(...) # saves the pruned checkpoint with sparsity masks
```
## Non-Standard Usage
If your goal is to easily perpare a network for accelerated inference, please follow the recipe above. However, ASP can also be used to perform experiments in advanced techniques like training with sparsity from initialization. For example, in order to recompute the sparse mask in between training steps, use the following method:
```
ASP.compute_sparse_masks()
```
A more thorough example can be found in `./test/toy_problem.py`.
\ No newline at end of file
A more thorough example can be found in `./test/toy_problem.py`.
import torch
import unittest
from apex.contrib.multihead_attn import SelfMultiheadAttn
class SelfMultiheadAttnTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.seq_length = 80
self.sequences = 10
self.hidden_dim = 1024
self.heads = 16
self.dropout_prob = 0.0
self.ref_layer = SelfMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=True,
include_norm_add=False,
separate_qkv_params=True,
mask_additive=True,
impl='default')
self.ref_layer.cuda().half()
self.ref_layer.reset_parameters()
self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
# Reset seed so parameters are identical
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.tst_layer = SelfMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=True,
include_norm_add=False,
separate_qkv_params=True,
mask_additive=True,
impl='fast')
self.tst_layer.cuda().half()
self.tst_layer.reset_parameters()
self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
def test_self_multihead_attn_additive_mask(self) :
grads = torch.randn_like(self.tst_inputs)
mask = ((torch.randn(self.sequences, self.seq_length) > 0) * -10000.0).half().cuda()
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs,
self.ref_inputs,
self.ref_inputs,
key_padding_mask=mask,
need_weights=False,
attn_mask=None,
is_training=True)
tst_outputs,_ = self.tst_layer.forward(self.tst_inputs,
self.tst_inputs,
self.tst_inputs,
key_padding_mask=mask,
need_weights=False,
attn_mask=None,
is_training=True)
self.ref_inputs.backward(grads)
self.tst_inputs.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))
if __name__ == '__main__':
unittest.main()
......@@ -76,7 +76,7 @@ class FusedLAMB(torch.optim.Optimizer):
import amp_C
self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
self.multi_tensor_lamb = amp_C.multi_tensor_lamb
else:
raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')
......@@ -117,7 +117,8 @@ class FusedLAMB(torch.optim.Optimizer):
else:
raise RuntimeError('FusedLAMB only support fp16 and fp32.')
g_norm_32, g_norm_16 = torch.zeros(1, device='cuda'), torch.zeros(1, device='cuda')
device = self.param_groups[0]["params"][0].device
g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device)
# compute grad norm for two lists
if len(g_all_32) > 0:
g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
......
......@@ -76,7 +76,8 @@ class FusedSGD(Optimizer):
def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False,
wd_after_momentum=False,
materialize_master_grads=True):
materialize_master_grads=True,
set_grad_none=False):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
......@@ -94,11 +95,12 @@ class FusedSGD(Optimizer):
self.materialize_master_grads = materialize_master_grads
self.most_recent_scale = 1.0
self.scale_set_by_backward = False
self.set_grad_none = set_grad_none
if multi_tensor_applier.available:
import amp_C
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
self.multi_tensor_sgd = amp_C.multi_tensor_sgd
else:
raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')
......@@ -108,6 +110,14 @@ class FusedSGD(Optimizer):
for group in self.param_groups:
group.setdefault('nesterov', False)
def zero_grad(self):
if self.set_grad_none:
for group in self.param_groups:
for p in group['params']:
p.grad = None
else:
super(FusedSGD, self).zero_grad()
def get_momentums(self, params):
momentums = []
first_run = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment