Unverified Commit 3fe10b55 authored by Burc Eryilmaz, committed by GitHub

Seryilmaz/fused dropout softmax (#985)

* fuse dropout into softmax in fprop for additive mask case
parent 6c186b3b
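For readers of this diff: the change can be summarized in one formula. Writing x for the BMM1 (pre-softmax) scores, a for the additive padding mask, p = 1 - dropout_prob for the keep probability, and m for the Bernoulli(p) keep mask now generated inside the kernel, the fused forward produces (approximately, up to fp16 rounding)

$$
y_i \;=\; \frac{m_i}{p}\,\mathrm{softmax}(x+a)_i
\;=\; \frac{m_i}{p}\cdot\frac{e^{\,x_i + a_i - \max_j (x_j + a_j)}}{\sum_j e^{\,x_j + a_j - \max_j (x_j + a_j)}} ,
$$

while m is also written out as a uint8 dropout mask for reuse in the backward pass.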
......@@ -113,7 +113,7 @@ torch::Tensor bwd_cuda(
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward<half, half, float,false>(
dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
......@@ -121,7 +121,7 @@ torch::Tensor bwd_cuda(
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
attn_batches*q_seq_len, stream);
//backward pass is completely in-place
return output_grads;
}
......
......@@ -115,7 +115,7 @@ torch::Tensor bwd_cuda(
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
if (padding_mask == nullptr) {
dispatch_masked_scale_softmax_backward<half, half, float,false>(
dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
......@@ -123,9 +123,9 @@ torch::Tensor bwd_cuda(
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
attn_batches*q_seq_len, stream);
} else{
dispatch_masked_scale_softmax_backward_masked_out<half, half, float,false>(
dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
......@@ -135,7 +135,7 @@ torch::Tensor bwd_cuda(
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
heads);
heads, stream);
}
//backward pass is completely in-place
......
#pragma once
//Philox CUDA.
class Philox {
public:
__device__ inline Philox(unsigned long long seed,
unsigned long long subsequence,
unsigned long long offset) {
key.x = (unsigned int)seed;
key.y = (unsigned int)(seed >> 32);
counter = make_uint4(0, 0, 0, 0);
counter.z = (unsigned int)(subsequence);
counter.w = (unsigned int)(subsequence >> 32);
STATE = 0;
incr_n(offset / 4);
}
__device__ inline uint4 operator()() {
if(STATE == 0) {
uint4 counter_ = counter;
uint2 key_ = key;
//7-round philox
for(int i = 0; i < 6; i++) {
counter_ = single_round(counter_, key_);
key_.x += (kPhilox10A); key_.y += (kPhilox10B);
}
output = single_round(counter_, key_);
incr();
}
//return a float4 directly
//unsigned long ret;
//switch(STATE) {
// case 0: ret = output.x; break;
// case 1: ret = output.y; break;
// case 2: ret = output.z; break;
// case 3: ret = output.w; break;
//}
//STATE = (STATE + 1) % 4;
return output;
}
private:
uint4 counter;
uint4 output;
uint2 key;
unsigned int STATE;
__device__ inline void incr_n(unsigned long long n) {
unsigned int nlo = (unsigned int)(n);
unsigned int nhi = (unsigned int)(n >> 32);
counter.x += nlo;
if (counter.x < nlo)
nhi++;
counter.y += nhi;
if (nhi <= counter.y)
return;
if (++counter.z)
return;
++counter.w;
}
__device__ inline void incr() {
if (++counter.x)
return;
if (++counter.y)
return;
if (++counter.z)
return;
++counter.w;
}
__device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
unsigned int *result_high) {
*result_high = __umulhi(a, b);
return a*b;
}
__device__ inline uint4 single_round(uint4 ctr, uint2 key) {
unsigned int hi0;
unsigned int hi1;
unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
return ret;
}
static const unsigned long kPhilox10A = 0x9E3779B9;
static const unsigned long kPhilox10B = 0xBB67AE85;
static const unsigned long kPhiloxSA = 0xD2511F53;
static const unsigned long kPhiloxSB = 0xCD9E8D57;
};
// Inverse of 2^32.
#define M_RAN_INVM32 2.3283064e-10f
__device__ __inline__ float4 uniform4(uint4 x) {
return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,x.w * M_RAN_INVM32);
}
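The Philox class above is a self-contained counter-based RNG. As a minimal illustrative sketch (not part of this commit; the kernel name and output buffer are hypothetical), it would be used from device code roughly like this:

// Illustrative sketch only -- this kernel is not part of the commit.
// It shows how Philox and uniform4() above combine to produce uniforms in [0, 1).
__global__ void fill_uniform_example(float4 *out, unsigned long long seed,
                                     unsigned long long offset, int n4) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= n4) return;
    Philox ph(seed, /*subsequence=*/tid, offset);  // per-thread subsequence => non-overlapping streams
    out[tid] = uniform4(ph());                     // one call yields four 32-bit draws mapped to [0, 1)
}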
......@@ -24,7 +24,9 @@ std::vector<torch::Tensor> bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
// torch::Tensor const& softmax_results,
torch::Tensor const& bmm1_results,
torch::Tensor const& pad_mask,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
......@@ -60,6 +62,7 @@ std::vector<torch::Tensor> fwd(
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(use_mask, "the no-mask case is not supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
......@@ -85,7 +88,8 @@ std::vector<torch::Tensor> bwd(
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& bmm1_results,
torch::Tensor const& pad_mask,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
......@@ -97,7 +101,6 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
......@@ -107,7 +110,6 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
......@@ -119,7 +121,8 @@ std::vector<torch::Tensor> bwd(
output_grads,
matmul2_results,
dropout_results,
softmax_results,
bmm1_results,
pad_mask,
input_lin_results,
inputs,
input_weights,
......
......@@ -63,7 +63,7 @@ std::vector<torch::Tensor> fwd_cuda(
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor bmm1_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
......@@ -75,7 +75,8 @@ std::vector<torch::Tensor> fwd_cuda(
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
void* bmm1_results_ptr = static_cast<void*>(bmm1_results.data_ptr());
void* dropout_results_ptr = static_cast<void*>(dropout_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
......@@ -119,23 +120,29 @@ std::vector<torch::Tensor> fwd_cuda(
lead_dim,
batch_stride,
beta_zero,
static_cast<half*>(softmax_results_ptr),
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
if (is_training) {
softmax_success = dispatch_additive_masked_softmax_dropout<half, half, float>(
reinterpret_cast<half*>(dropout_results_ptr),
(is_training) ? reinterpret_cast<uint8_t*>(dropout_mask.data_ptr<uint8_t>()) : nullptr,
reinterpret_cast<const half*>(bmm1_results_ptr),
pad_mask,
attn_batches*q_seq_len*q_seq_len,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences,
1.0f-dropout_prob,
stream);
} else {
softmax_success = dispatch_additive_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
reinterpret_cast<half*>(dropout_results_ptr), // in eval mode this buffer holds the plain softmax output, so Matmul2 can always read dropout_results
reinterpret_cast<const half*>(bmm1_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
......@@ -143,14 +150,6 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches*q_seq_len/sequences);
}
if (is_training) {
//use at:: function so that C++ version generates the same random mask as python version
auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
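To summarize the control flow after this change (the separate at::_fused_dropout call above becomes unnecessary because the mask is now drawn inside the fused kernel): the buffer consumed by Matmul2 is always dropout_results, whose contents depend on the mode, roughly

$$
\texttt{dropout\_results} =
\begin{cases}
\dfrac{m}{1-p_{\text{drop}}}\odot\mathrm{softmax}(x+a) & \text{training,}\\[1ex]
\mathrm{softmax}(x+a) & \text{evaluation.}
\end{cases}
$$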
// Matmul2
gemm_switch_fp32accum( state,
a_layout_n,
......@@ -162,7 +161,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta_zero,
......@@ -199,7 +198,7 @@ std::vector<torch::Tensor> fwd_cuda(
return {
input_lin_results,
softmax_results,
bmm1_results,
dropout_results,
dropout_mask,
matmul2_results,
......@@ -212,7 +211,8 @@ std::vector<torch::Tensor> bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& bmm1_results,
torch::Tensor const& pad_mask,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
......@@ -350,15 +350,18 @@ std::vector<torch::Tensor> bwd_cuda(
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward<half, half, float,false>(
dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
static_cast<half* const>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(bmm1_results.data_ptr()),
reinterpret_cast<half const*>(pad_mask.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
attn_batches*q_seq_len/sequences,
attn_batches*q_seq_len,
stream);
// Matmul1 Dgrad1
gemm_switch_fp32accum( state,
......
......@@ -361,7 +361,7 @@ std::vector<torch::Tensor> bwd_cuda(
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward<half, half, float,false>(
dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
......@@ -369,7 +369,7 @@ std::vector<torch::Tensor> bwd_cuda(
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
attn_batches*q_seq_len, stream);
// Matmul1 Dgrad1
gemm_switch_fp32accum( state,
......
#pragma once
//#include <ATen/ATen.h>
#ifdef OLD_GENERATOR
#include <ATen/CUDAGenerator.h>
#else
#include <ATen/CUDAGeneratorImpl.h>
#endif
//#include <ATen/cuda/CUDAContext.h>
#include <curand_kernel.h>
#include "philox.h"
//#include <THC/THCGeneral.h>
#include <assert.h>
#include <cfloat>
......@@ -17,6 +29,14 @@ namespace {
template <>
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }  // 4 mask bytes moved as one 4-byte (half2-sized) transaction
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_mask(Datatype *dst, Datatype value, const uint8_t *src);
......@@ -24,9 +44,20 @@ namespace {
__device__ __inline__ void apply_mask<__half, 1>(__half *dst, __half value, const uint8_t *src) {
if (*src == 1) { *dst = value; }
}
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_additive_mask(Datatype *dst, const Datatype *additive_mask);
template <>
__device__ __inline__ void apply_additive_mask<__half, 1>(__half *dst, const __half *additive_mask) {
*dst += *additive_mask;
}
template <>
__device__ __inline__ void apply_additive_mask<__half, 4>(__half *dst, const __half *additive_mask) {
*dst += *additive_mask;
*(dst+1) += *(additive_mask+1);
*(dst+2) += *(additive_mask+2);
*(dst+3) += *(additive_mask+3);}
} // namespace anonymous
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Warp Softmax forward
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -156,6 +187,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batc
}
}
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
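A concrete reading of these template parameters, consistent with the dispatch tables later in this file: with n = softmax_elements rounded up to the next power of two,

$$
\mathrm{WARP\_SIZE}=\min(n,32),\qquad
\mathrm{WARP\_ITERATIONS}=\frac{n}{\mathrm{WARP\_SIZE}},\qquad
\mathrm{WARP\_BATCH}=\begin{cases}2 & n\le 128\\ 1 & \text{otherwise,}\end{cases}
$$

so, for example, 256 elements select the <1, 8, 32> instantiation: one row per warp, eight iterations of 32 lanes.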
......@@ -237,10 +269,428 @@ bool dispatch_softmax(output_t *dst, const input_t *src, int softmax_elements, i
// compute launch size
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
return true;
}
return false;
}
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE, int ELEMENTS_PER_LDG_STG>
__global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride, std::pair<uint64_t,uint64_t> seeds, float p)
{
assert(ELEMENTS_PER_LDG_STG==4);
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;
acc_t pinv = acc_t(1)/p;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
//vectorize if element_count is multiple of 4, else don't vectorize
input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
src += thread_offset;
dst += thread_offset;
dropout_mask += thread_offset;
// load data from global memory
for (int i = 0;i < WARP_BATCH;++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;
const half* curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
//masking_value is a large negative value
elements_input[i][it + element] = -10000;
}
if (element_index < batch_element_count) {
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);
apply_additive_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], curr_mask + itr_jmp); //(__half)-std::numeric_limits<float>::infinity()
}
}
}
// convert input_t to acc_t
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;++it) {
elements[i][it] = elements_input[i][it];
}
}
constexpr uint32_t FULL_MASK = 0xffffffff;
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1;it < WARP_ITERATIONS;++it) {
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
Philox ph(seeds.first, tid, seeds.second);
uint8_t rands[WARP_BATCH][WARP_ITERATIONS];
float4 rand_num;
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
rand_num = uniform4(ph());
rands[i][it] = (uint8_t)(rand_num.x <= p);   // keep element iff rand <= p (p is the keep probability)
rands[i][it+1] = (uint8_t)(rand_num.y <= p);
rands[i][it+2] = (uint8_t)(rand_num.z <= p);
rands[i][it+3] = (uint8_t)(rand_num.w <= p);
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(dropout_mask + i * element_count + it * WARP_SIZE, &rands[i][it]);
}
}
}
// store result
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
out[element] = rands[i][it+element] * (pinv * (elements[i][it + element] / sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
}
else {
break;
}
}
}
}
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE, int ELEMENTS_PER_LDG_STG>
__global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride, std::pair<uint64_t,uint64_t> seeds, float p)
{
assert(ELEMENTS_PER_LDG_STG==1);
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;
acc_t pinv = acc_t(1)/p;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
//vectorize if element_count is multiple of 4, else don't vectorize
input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
int thread_offset = first_batch * stride + local_idx;
src += thread_offset;
dst += thread_offset;
dropout_mask += thread_offset;
// load data from global memory
for (int i = 0;i < WARP_BATCH;++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + local_idx;
const half* curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += 1) {
int element_index = local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0;element < 1;++element) {
//masking_value is a large negative value
elements_input[i][it + element] = -10000;
}
if (element_index < batch_element_count) {
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
copy_vector<input_t, 1>(&elements_input[i][it], src + itr_idx);
apply_additive_mask<input_t, 1>(&elements_input[i][it], curr_mask + itr_jmp);
}
}
}
// convert input_t to acc_t
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;++it) {
elements[i][it] = elements_input[i][it];
}
}
constexpr uint32_t FULL_MASK = 0xffffffff;
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1;it < WARP_ITERATIONS;++it) {
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
curandStatePhilox4_32_10_t state;
curand_init(
seeds.first,
tid,
seeds.second,
&state);
// store result
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += 1) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
output_t out[1];
acc_t softmax_out[1];
uint8_t dropout_mask_temp[1];
//generate a vector of random numbers here
float rand = curand_uniform(&state);
float *rand_ptr = (float*)(&rand);
#pragma unroll
for (int element = 0;element < 1;++element) {
softmax_out[element] = (elements[i][it + element] / sum[i]);
rand_ptr[element] = rand_ptr[element] <= p;
out[element] = rand_ptr[element] * pinv * softmax_out[element];
dropout_mask_temp[element] = rand_ptr[element] > 0.5; // just to distinguish 0.0f and 1.0f
}
copy_vector<output_t, 1>(dst + i * element_count + it * WARP_SIZE, out);
copy_vector<uint8_t, 1>(dropout_mask + i * element_count + it * WARP_SIZE, dropout_mask_temp);
}
else {
break;
}
}
}
}
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG is 1 for the scalar kernel and 4 for the vectorized (vec4) kernel.
template <typename input_t, typename output_t, typename acc_t>
using additive_masked_softmax_dropout_forward_func = void(*)(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride, std::pair<uint64_t,uint64_t> seeds, float p);
template <typename input_t, typename output_t, typename acc_t>
bool warp_additive_masked_softmax_dropout_kernel(int element_count, int log2_elements, int &warp_size, int &batches_per_warp, additive_masked_softmax_dropout_forward_func<input_t, output_t, acc_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
bool flag_vec4 = (element_count % 4 == 0);
switch (log2_elements) {
case 0: // 1
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,1,1>;
break;
case 1: // 2
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,2,1>;
break;
case 2: // 4
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,4,1>;
break;
case 3: // 8
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,8,1>;
break;
case 4: // 16
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,16,1>;
break;
case 5: // 32
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,32,1>;
break;
case 6: // 64
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,2,32,1>;
break;
case 7: // 128
if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 2,4,32,4>;
else kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;
//kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;
break;
case 8: // 256
if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,8,32,4>;
else kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,8,32,1>;
break;
case 9: // 512
if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,16,32,4>;
else kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,16,32,1>;
break;
case 10: // 1024
if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,32,32,4>;
else kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,32,32,1>;
break;
case 11: // 2048
if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,64,32,4>;
else kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,64,32,1>;
break;
default:
return false;
}
return true;
}
template<typename input_t, typename output_t, typename acc_t>
bool dispatch_additive_masked_softmax_dropout(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int totalElements, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride, float p, cudaStream_t streamid)// p is the probability to keep, not drop
{
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 2048) {
// compute function index. there's a function for each power of two size up to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements) ++log2_elements;
additive_masked_softmax_dropout_forward_func<input_t, output_t, acc_t> kernel;
int warp_size, batches_per_warp;
if (!warp_additive_masked_softmax_dropout_kernel<input_t, output_t, acc_t>(softmax_elements, log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
// use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
// compute warps per block.
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
auto gen = at::cuda::detail::getDefaultCUDAGenerator();
int64_t counter_offset = (totalElements/(blocks*threads_per_block)+1);
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
#ifdef OLD_GENERATOR
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
#else
std::lock_guard<std::mutex> lock(gen.mutex());
rng_engine_inputs = at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(counter_offset);
#endif
}
// compute launch size
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
kernel<<<blocks, threads, 0, streamid>>>(dst, dropout_mask, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride, rng_engine_inputs, p);
return true;
}
return false;
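A note on the RNG bookkeeping above (my reading, not stated in the diff): philox_engine_inputs(counter_offset) hands the kernel the current (seed, offset) pair and advances the generator's stored offset by

$$
\texttt{counter\_offset}=\Big\lfloor\frac{\texttt{totalElements}}{\texttt{blocks}\times\texttt{threads\_per\_block}}\Big\rfloor+1 ,
$$

i.e. roughly the number of random values each thread may consume, so a later launch starts far enough along the Philox stream that its draws do not overlap this one's.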
......@@ -1214,8 +1664,6 @@ void dispatch_masked_scale_softmax_backward_masked_out_stream(output_t *grad_inp
}
}
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>
__global__ void masked_scale_softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output, const uint8_t *mask, acc_t scale, int batch_size, int stride, int element_count)
{
......@@ -1296,81 +1744,285 @@ __global__ void masked_scale_softmax_warp_backward(output_t *gradInput, const in
}
}
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_masked_scale_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int batch_count)
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE=32, int ELEMENTS_PER_LDG_STG, bool is_log_softmax>
__global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput, const input_t *grad, const input_t *softmax_input, const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size, int stride, int pad_batch_stride, int element_count)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil_native(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x % WARP_SIZE;
//vectorize if a row length is multiple of 4
int flag_vec4 = ((element_count & 3) == 0);  // note: '==' binds tighter than '&', so the parentheses are required
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] ;
input_t elements_input[WARP_BATCH][WARP_ITERATIONS] ;
// the first element to process by the current thread
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
softmax_input += thread_offset;
gradInput += thread_offset;
mask += thread_offset;
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
// but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
// the nested loops.
// This should have no impact on performance because the loops are unrolled anyway.
// load data from global memory
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;
const input_t* curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
//masking_value is a large negative value
elements_input[i][it + element] = -10000;
grad_reg[i][it+element] = acc_t(0);
}
if (element_index < batch_element_count) {
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], softmax_input + itr_idx);
apply_additive_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], curr_mask + itr_jmp); //(__half)-std::numeric_limits<float>::infinity()
uint8_t mask_temp[ELEMENTS_PER_LDG_STG];
input_t grad_temp[ELEMENTS_PER_LDG_STG];
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(&mask_temp[0], mask + itr_idx);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_temp[0], grad + itr_idx);
#pragma unroll
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
grad_reg[i][it+element] = ((acc_t)mask_temp[element] * (acc_t)grad_temp[element] * (acc_t)scale );
}
}
}
}
// load data from global memory
// convert input_t to acc_t
// TODO : remove this, input is already acc_t type in register
acc_t elements[WARP_BATCH][WARP_ITERATIONS] ;
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;++it) {
elements[i][it] = elements_input[i][it];
}
}
constexpr uint32_t FULL_MASK = 0xffffffff;
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1;it < WARP_ITERATIONS;++it) {
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
//elements[i][it] = expf(elements[i][it] - max_value[i]);
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
// store result
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it ++) {
elements[i][it] = elements[i][it] / sum[i];
grad_reg[i][it] = grad_reg[i][it] * elements[i][it];
}
}
acc_t grad_sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
grad_sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
grad_sum[i] += grad_reg[i][it];
}
}
warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(grad_sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t grad_input_reg[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element=0; element<ELEMENTS_PER_LDG_STG; element++) {
if (is_log_softmax) {
grad_input_reg[element] = (grad_reg[i][it+element] - std::exp(elements[i][it+element]) * grad_sum[i]);
} else {
grad_input_reg[element] = (grad_reg[i][it+element] - elements[i][it+element] * grad_sum[i]);
}
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, grad_input_reg);
}
}
}
}
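For reference, the recompute kernel above implements the usual softmax backward, just with the softmax recomputed from the saved BMM1 output plus the additive mask instead of being read from memory. With y = softmax(x + a), dropout keep mask m, scale = 1/(1 - p_drop) and incoming gradient dL/dỹ, it computes

$$
g_i=\frac{m_i}{1-p_{\text{drop}}}\,\frac{\partial L}{\partial \tilde y_i},\qquad
\frac{\partial L}{\partial x_i}=y_i\Big(g_i-\sum_j g_j\,y_j\Big),
$$

which matches the elementwise products and the warp-reduced grad_sum in the code.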
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
using masked_scale_softmax_warp_backward_recompute_func = void(*)(output_t *gradInput, const input_t *grad, const input_t *softmax_input, const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size, int stride, int pad_batch_stride, int element_count);
template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
bool masked_scale_softmax_warp_backward_recompute_kernel(int element_count, int log2_elements, int &warp_size, int &batches_per_warp, masked_scale_softmax_warp_backward_recompute_func<input_t, output_t, acc_t, is_log_softmax> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
bool flag_vec4 = (element_count % 4 == 0);
switch (log2_elements) {
case 0: // 1
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,1,1, is_log_softmax>;
break;
case 1: // 2
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,2,1, is_log_softmax>;
break;
case 2: // 4
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,4,1, is_log_softmax>;
break;
case 3: // 8
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,8,1, is_log_softmax>;
break;
case 4: // 16
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,16,1, is_log_softmax>;
break;
case 5: // 32
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,32,1, is_log_softmax>;
break;
case 6: // 64
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,2,32,1, is_log_softmax>;
break;
case 7: // 128
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,4,32,1, is_log_softmax>;
break;
case 8: // 256
if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,8,32,4, is_log_softmax>;
else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,8,32,1, is_log_softmax>;
break;
case 9: // 512
if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,16,32,4, is_log_softmax>;
else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,16,32,1, is_log_softmax>;
break;
case 10: // 1024
if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,32,32,4, is_log_softmax>;
else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,32,32,1, is_log_softmax>;
break;
case 11: // 2048
if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,64,32,4, is_log_softmax>;
else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,64,32,1, is_log_softmax>;
break;
default:
return false;
}
return true;
}
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
bool dispatch_masked_scale_softmax_backward_recompute(output_t *grad_input, const input_t *grad, const input_t *softmax_input, const input_t *pad_mask, const uint8_t *mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int pad_batch_stride, int batch_count, cudaStream_t streamid)
{
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 2048) {
// compute function index. there's a function for each power of two size up to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements) ++log2_elements;
masked_scale_softmax_warp_backward_recompute_func<input_t, output_t, acc_t, is_log_softmax> kernel;
int warp_size, batches_per_warp;
if (!masked_scale_softmax_warp_backward_recompute_kernel<input_t, output_t, acc_t, is_log_softmax>(softmax_elements, log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
// use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
// compute warps per block.
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
// compute launch size
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 0, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 1, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 2, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 3, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 4, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 5, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 6, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 7, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 8, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 9, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 10, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
// launch
kernel<<<blocks, threads, 0, streamid>>>(grad_input, grad, softmax_input, pad_mask, mask, scale, batch_count, softmax_elements_stride, pad_batch_stride, softmax_elements);
return true;
}
return false;
}
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_masked_scale_softmax_backward_stream(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int batch_count, cudaStream_t streamid)
{
......
......@@ -11,6 +11,7 @@ class FastSelfAttnFunc(torch.autograd.Function) :
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
mask_additive_t = torch.tensor([mask_additive])
if use_biases_t[0]:
if not mask_additive:
......@@ -32,9 +33,24 @@ class FastSelfAttnFunc(torch.autograd.Function) :
output_biases, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(use_biases_t, \
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
null_tensor, \
null_tensor, \
mask_additive_t, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t)
else:
input_lin_results, \
softmax_results, \
bmm1_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
......@@ -51,6 +67,20 @@ class FastSelfAttnFunc(torch.autograd.Function) :
output_biases, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(use_biases_t, \
heads_t, \
matmul2_results, \
dropout_results, \
null_tensor, \
bmm1_results, \
pad_mask, \
mask_additive_t, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t)
else:
......@@ -70,20 +100,20 @@ class FastSelfAttnFunc(torch.autograd.Function) :
output_weights, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(use_biases_t, \
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t)
ctx.save_for_backward(use_biases_t, \
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
null_tensor, \
null_tensor, \
mask_additive_t, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t)
return outputs.detach()
@staticmethod
......@@ -93,6 +123,9 @@ class FastSelfAttnFunc(torch.autograd.Function) :
matmul2_results, \
dropout_results, \
softmax_results, \
bmm1_results, \
pad_mask, \
mask_additive_t, \
input_lin_results, \
inputs, \
input_weights, \
......@@ -101,24 +134,45 @@ class FastSelfAttnFunc(torch.autograd.Function) :
dropout_prob_t = ctx.saved_tensors
if use_biases_t[0]:
input_grads, \
input_weight_grads, \
output_weight_grads, \
input_bias_grads, \
output_bias_grads = \
fast_self_multihead_attn_bias.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
if not mask_additive_t[0]:
input_grads, \
input_weight_grads, \
output_weight_grads, \
input_bias_grads, \
output_bias_grads = \
fast_self_multihead_attn_bias.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
else:
input_grads, \
input_weight_grads, \
output_weight_grads, \
input_bias_grads, \
output_bias_grads = \
fast_self_multihead_attn_bias_additive_mask.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
bmm1_results, \
pad_mask, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
else:
input_bias_grads = None
output_bias_grads = None
......
import torch
import unittest
from apex.contrib.multihead_attn import SelfMultiheadAttn
class SelfMultiheadAttnTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.seq_length = 80
self.sequences = 10
self.hidden_dim = 1024
self.heads = 16
self.dropout_prob = 0.0
self.ref_layer = SelfMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=True,
include_norm_add=False,
separate_qkv_params=True,
mask_additive=True,
impl='default')
self.ref_layer.cuda().half()
self.ref_layer.reset_parameters()
self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
# Reset seed so parameters are identical
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.tst_layer = SelfMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=True,
include_norm_add=False,
separate_qkv_params=True,
mask_additive=True,
impl='fast')
self.tst_layer.cuda().half()
self.tst_layer.reset_parameters()
self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
def test_self_multihead_attn_additive_mask(self) :
grads = torch.randn_like(self.tst_inputs)
mask = ((torch.randn(self.sequences, self.seq_length) > 0) * -10000.0).half().cuda()
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs,
self.ref_inputs,
self.ref_inputs,
key_padding_mask=mask,
need_weights=False,
attn_mask=None,
is_training=True)
tst_outputs,_ = self.tst_layer.forward(self.tst_inputs,
self.tst_inputs,
self.tst_inputs,
key_padding_mask=mask,
need_weights=False,
attn_mask=None,
is_training=True)
self.ref_inputs.backward(grads)
self.tst_inputs.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))
if __name__ == '__main__':
unittest.main()
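A brief note on the mask used in this test (my explanation, not part of the diff): multiplying the boolean key-padding mask by -10000 turns it into an additive mask whose masked positions contribute essentially nothing to the softmax, matching the masking_value of -10000 used inside the fused kernels:

$$
\mathrm{softmax}(x+a)_i \;\propto\; e^{\,x_i-10000}\;\approx\;0
\quad\text{at masked positions } i .
$$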