Unverified Commit db92ee13 authored by Jithun Nair, committed by GitHub

Merge pull request #64 from ROCmSoftwarePlatform/IFU-master-2021-12-08

IFU-master-2021-12-08
parents d150afdc 68364b49
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "softmax.h"
namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {
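// fwd_cuda: softmax over the last dimension of `input`
// ([attn_batches, q_seq_len, k_seq_len]), optionally restricted by a
// per-sequence uint8 pad_mask, followed by fused dropout when is_training.
// Returns {dropout_results, dropout_mask, softmax_results}.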
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const uint8_t *pad_mask,
                                    float dropout_prob) {
const int attn_batches = input.size(0);
const int sequences = attn_batches / heads;
const int q_seq_len = input.size(1);
@@ -41,64 +34,55 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
auto act_options = input.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *input_ptr = static_cast<void *>(input.data_ptr());
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
        attn_batches * q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
        k_seq_len, attn_batches * q_seq_len,
        attn_batches * q_seq_len / sequences);
}
if (is_training) {
    // use at:: function so that C++ version generates the same random mask as
    // python version
    auto dropout_tuple =
        at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
// Matmul2
  return {dropout_results, dropout_mask, softmax_results};
}
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask,
                       const uint8_t *padding_mask, float dropout_prob) {
const int attn_batches = output_grads.size(0);
const int q_seq_len = output_grads.size(1);
const int k_seq_len = q_seq_len;
@@ -110,38 +94,31 @@ torch::Tensor bwd_cuda(
cublasSetStream(handle, stream);
// Output Tensor Allocations
  // torch::Tensor input_grads = torch::empty_like(output_grads);
// Apply Dropout Mask and Scale by Dropout Probability
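  // (Inverted dropout: surviving activations were scaled by 1/(1 - p) in the
  // forward pass, so the same mask and scale are applied to the gradients.)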
// Softmax Grad
if (padding_mask == nullptr) {
    dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
        static_cast<half *>(output_grads.data_ptr()),
        static_cast<half *>(output_grads.data_ptr()),
        reinterpret_cast<half const *>(softmax_results.data_ptr()),
        static_cast<uint8_t const *>(dropout_mask.data_ptr()),
        1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
        attn_batches * q_seq_len, stream);
  } else {
    dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float,
                                                             false>(
        static_cast<half *>(output_grads.data_ptr()),
        static_cast<half *>(output_grads.data_ptr()),
        reinterpret_cast<half const *>(softmax_results.data_ptr()),
        static_cast<uint8_t const *>(dropout_mask.data_ptr()),
        static_cast<uint8_t const *>(padding_mask), 1.0 / (1.0 - dropout_prob),
        k_seq_len, k_seq_len, attn_batches * q_seq_len, heads, stream);
}
  // backward pass is completely in-place
return output_grads;
}
} // namespace mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
#pragma once
// Philox CUDA.
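// Counter-based RNG: each call to operator() runs several Philox rounds on a
// 128-bit counter under a 64-bit key and yields four 32-bit random words.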
class Philox {
public:
@@ -15,28 +15,30 @@ public:
incr_n(offset / 4);
}
__device__ inline uint4 operator()() {
    if (STATE == 0) {
uint4 counter_ = counter;
uint2 key_ = key;
      // 7-round philox
      for (int i = 0; i < 6; i++) {
counter_ = single_round(counter_, key_);
        key_.x += (kPhilox10A);
        key_.y += (kPhilox10B);
}
output = single_round(counter_, key_);
incr();
}
    // return a float4 directly
    // unsigned long ret;
    // switch(STATE) {
    //   case 0: ret = output.x; break;
    //   case 1: ret = output.y; break;
    //   case 2: ret = output.z; break;
    //   case 3: ret = output.w; break;
    // }
    // STATE = (STATE + 1) % 4;
return output;
}
private:
uint4 counter;
uint4 output;
@@ -67,7 +69,7 @@ private:
__device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
unsigned int *result_high) {
*result_high = __umulhi(a, b);
    return a * b;
}
__device__ inline uint4 single_round(uint4 ctr, uint2 key) {
unsigned int hi0;
@@ -85,6 +87,6 @@ private:
// Inverse of 2^32.
#define M_RAN_INVM32 2.3283064e-10f
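// uniform4 maps the four 32-bit words of a Philox draw to floats in [0, 1)
// by scaling with 2^-32.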
__device__ __inline__ float4 uniform4(uint4 x) {
  return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32,
                     x.z * M_RAN_INVM32, x.w * M_RAN_INVM32);
}
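For readers new to this header, the sketch below shows one way these helpers are commonly combined into a dropout-mask kernel. It is illustrative only and not part of this change: the (seed, subsequence, offset) constructor arguments, the launch arithmetic and the keep-probability convention are assumptions about the elided parts of this file.

__global__ void dropout_mask_sketch(uint8_t *mask, int n, float keep_prob,
                                    unsigned long long seed,
                                    unsigned long long offset) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  // Assumed constructor: a per-thread subsequence keeps streams independent.
  Philox rng(seed, tid, offset);
  // Each operator() call yields four 32-bit words -> four keep/drop decisions.
  for (int i = tid * 4; i < n; i += gridDim.x * blockDim.x * 4) {
    float4 r = uniform4(rng());
    if (i + 0 < n) mask[i + 0] = r.x < keep_prob;
    if (i + 1 < n) mask[i + 1] = r.y < keep_prob;
    if (i + 2 < n) mask[i + 2] = r.z < keep_prob;
    if (i + 3 < n) mask[i + 3] = r.w < keep_prob;
  }
}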
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace self_bias_additive_mask {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    torch::Tensor const &input_biases,
                                    torch::Tensor const &output_biases,
                                    const half *pad_mask, float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    // torch::Tensor const& softmax_results,
    torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    // torch::Tensor const& input_biases,
    // torch::Tensor const& output_biases,
    torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &input_weights,
    torch::Tensor const &output_weights, torch::Tensor const &input_biases,
    torch::Tensor const &output_biases, torch::Tensor const &pad_mask,
    float dropout_prob) {
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(use_mask, "no mask is not supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
               "Only Half is supported");
}
  return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights,
                  output_weights, input_biases, output_biases,
                  use_mask ? static_cast<const half *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
@@ -107,29 +82,26 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");
  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  bmm1_results, pad_mask, input_lin_results, inputs,
                  input_weights, output_weights, dropout_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
@@ -140,4 +112,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
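For orientation, a minimal host-side caller of the binding above could look like the following sketch. It is not part of this commit: the tensor shapes, the 3 * embed_dim QKV projection size and the additive-mask convention are assumptions chosen to satisfy the checks in fwd(); only the argument order and the half/CUDA requirements come from the code itself.

#include <torch/torch.h>

void example_forward_call(int heads) {
  const int seq_len = 64, batch = 8, embed_dim = 1024; // illustrative shapes
  auto opts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
  auto inputs = torch::randn({seq_len, batch, embed_dim}, opts);
  auto input_weights = torch::randn({3 * embed_dim, embed_dim}, opts);
  auto output_weights = torch::randn({embed_dim, embed_dim}, opts);
  auto input_biases = torch::zeros({3 * embed_dim}, opts);
  auto output_biases = torch::zeros({embed_dim}, opts);
  // Additive half-precision mask: 0 keeps a position, a large negative
  // value suppresses it (assumed convention for this variant).
  auto pad_mask = torch::zeros({batch, seq_len}, opts);
  auto outs = multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd(
      /*use_mask=*/true, /*use_time_mask=*/false, /*is_training=*/true, heads,
      inputs, input_weights, output_weights, input_biases, output_biases,
      pad_mask, /*dropout_prob=*/0.1f);
}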
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "layer_norm.h"
#include "softmax.h"
#include "strided_batched_gemm.h"
namespace multihead_attn {
namespace self_bias_additive_mask {
@@ -58,25 +55,33 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
auto act_options = inputs.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
  torch::Tensor input_lin_results =
      torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor bmm1_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options);
  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
  void *k_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *bmm1_results_ptr = static_cast<void *>(bmm1_results.data_ptr());
  void *dropout_results_ptr = static_cast<void *>(dropout_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
@@ -111,8 +116,7 @@ std::vector<torch::Tensor> fwd_cuda(
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
@@ -136,32 +140,28 @@ std::vector<torch::Tensor> fwd_cuda(
// Padded Softmax
bool softmax_success = false;
if (is_training) {
    softmax_success =
        dispatch_additive_masked_softmax_dropout<half, half, float>(
            reinterpret_cast<half *>(dropout_results_ptr),
            (is_training)
                ? reinterpret_cast<uint8_t *>(dropout_mask.data_ptr<uint8_t>())
                : nullptr,
            reinterpret_cast<const half *>(bmm1_results_ptr), pad_mask,
            attn_batches * q_seq_len * q_seq_len, k_seq_len, k_seq_len,
            attn_batches * q_seq_len, attn_batches * q_seq_len / sequences,
            1.0f - dropout_prob, stream);
} else {
softmax_success = dispatch_additive_masked_softmax<half, half, float>(
        reinterpret_cast<half *>(
            dropout_results_ptr), // this is actually softmax results, but
                                  // making it consistent for the next function
        reinterpret_cast<const half *>(bmm1_results_ptr), pad_mask, k_seq_len,
        k_seq_len, attn_batches * q_seq_len,
        attn_batches * q_seq_len / sequences);
}
// Matmul2
  gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
@@ -211,31 +211,17 @@ std::vector<torch::Tensor> fwd_cuda(
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  return {input_lin_results, bmm1_results, dropout_results,
          dropout_mask, matmul2_results, outputs};
}
std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
@@ -266,13 +252,17 @@ std::vector<torch::Tensor> bwd_cuda(
at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
  auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
  auto k_lin_results_ptr =
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr =
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr =
      static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr =
      static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
@@ -335,8 +325,7 @@ std::vector<torch::Tensor> bwd_cuda(
  auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
// MatMul2 Dgrad1
  gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
@@ -358,8 +347,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul2 Dgrad2
  gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
@@ -396,8 +384,7 @@ std::vector<torch::Tensor> bwd_cuda(
stream);
// Matmul1 Dgrad1
  gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
@@ -419,8 +406,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul1 Dgrad2
  gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
@@ -496,13 +482,8 @@ std::vector<torch::Tensor> bwd_cuda(
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  return {input_grads, input_weight_grads, output_weight_grads,
          input_bias_grads, output_bias_grads};
}
} // end namespace rocblas_gemmex
@@ -5,94 +5,70 @@ namespace multihead_attn {
namespace self_bias {
namespace rocblas_gemmex {
std::vector<torch::Tensor>
fwd_cuda(bool use_time_mask, bool is_training, int heads,
         torch::Tensor const &inputs, torch::Tensor const &input_weights,
         torch::Tensor const &output_weights, torch::Tensor const &input_biases,
         torch::Tensor const &output_biases, const uint8_t *pad_mask,
         float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    // torch::Tensor const& input_biases,
    // torch::Tensor const& output_biases,
    torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &input_weights,
    torch::Tensor const &output_weights, torch::Tensor const &input_biases,
    torch::Tensor const &output_biases, torch::Tensor const &pad_mask,
    float dropout_prob) {
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
}
  return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights,
                  output_weights, input_biases, output_biases,
                  use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
@@ -103,29 +79,28 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");
  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_results, inputs, input_weights,
                  output_weights, dropout_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
@@ -136,4 +111,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "layer_norm.h"
#include "softmax.h"
#include "strided_batched_gemm.h"
namespace multihead_attn {
namespace self_bias {
namespace rocblas_gemmex {
std::vector<torch::Tensor>
fwd_cuda(bool use_time_mask, bool is_training, int heads,
         torch::Tensor const &inputs, torch::Tensor const &input_weights,
         torch::Tensor const &output_weights, torch::Tensor const &input_biases,
         torch::Tensor const &output_biases, const uint8_t *pad_mask,
         float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
@@ -58,24 +48,32 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
auto act_options = inputs.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
  torch::Tensor input_lin_results =
      torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options);
  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
  void *k_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
@@ -110,8 +108,7 @@ std::vector<torch::Tensor> fwd_cuda(
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
@@ -136,44 +133,35 @@ std::vector<torch::Tensor> fwd_cuda(
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
        k_seq_len, attn_batches * q_seq_len);
} else {
if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len,
          attn_batches * q_seq_len / sequences);
}
}
if (is_training) {
    // use at:: function so that C++ version generates the same random mask as
    // python version
    auto dropout_tuple =
        at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
// Matmul2
  gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
@@ -223,30 +211,17 @@ std::vector<torch::Tensor> fwd_cuda(
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  return {input_lin_results, softmax_results, dropout_results,
          dropout_mask, matmul2_results, outputs};
}
std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
@@ -277,13 +252,17 @@ std::vector<torch::Tensor> bwd_cuda(
at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
  auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
  auto k_lin_results_ptr =
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr =
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr =
      static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr =
      static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
@@ -346,8 +325,7 @@ std::vector<torch::Tensor> bwd_cuda(
  auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
// MatMul2 Dgrad1
  gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
@@ -369,8 +347,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul2 Dgrad2
  gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
@@ -393,19 +370,16 @@ std::vector<torch::Tensor> bwd_cuda(
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
  dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
      static_cast<half *>(matmul2_grads.data_ptr()),
      static_cast<half *>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()),
      static_cast<uint8_t const *>(dropout_mask.data_ptr()),
      1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
      attn_batches * q_seq_len, stream);
// Matmul1 Dgrad1
  gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
@@ -427,8 +401,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul1 Dgrad2
  gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
@@ -503,15 +476,11 @@ std::vector<torch::Tensor> bwd_cuda(
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  return {input_grads, input_weight_grads, output_weight_grads,
          input_bias_grads, output_bias_grads};
}
} // end namespace rocblas_gemmex
} // end namespace self_bias
} // end namespace multihead_attn
@@ -5,87 +5,66 @@ namespace multihead_attn {
namespace self {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &input_weights,
    torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
    float dropout_prob) {
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
}
return fwd_cuda(
      use_time_mask, is_training, heads, inputs, input_weights, output_weights,
      use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr()) : nullptr,
      dropout_prob);
}
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
@@ -96,29 +75,28 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");
  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_results, inputs, input_weights,
                  output_weights, dropout_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
@@ -129,4 +107,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self::rocblas_gemmex::fwd, "Self Multihead Attention Forward.");
m.def("backward", &multihead_attn::self::rocblas_gemmex::bwd, "Self Multihead Attention Backward.");
}
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "layer_norm.h"
#include "softmax.h"
#include "strided_batched_gemm.h"
namespace multihead_attn {
namespace self {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
@@ -55,24 +47,32 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
auto act_options = inputs.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
  torch::Tensor input_lin_results =
      torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options);
  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
  void *k_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
@@ -106,8 +106,7 @@ std::vector<torch::Tensor> fwd_cuda(
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
@@ -132,46 +131,35 @@ std::vector<torch::Tensor> fwd_cuda(
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
        k_seq_len, attn_batches * q_seq_len);
} else {
if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len,
          attn_batches * q_seq_len / sequences);
}
}
assert(softmax_success);
if (is_training) {
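    // Unlike the self_bias variants above, this path fills dropout_mask with
    // the extension's own apex_fused_dropout_cuda kernel instead of
    // at::_fused_dropout.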
    apex_fused_dropout_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(softmax_results.data_ptr()),
        static_cast<at::Half *>(dropout_results.data_ptr()),
        static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
        (1.0f - dropout_prob));
}
// Matmul2
  gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
@@ -219,30 +207,17 @@ std::vector<torch::Tensor> fwd_cuda(
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  return {input_lin_results, softmax_results, dropout_results,
          dropout_mask, matmul2_results, outputs};
}
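// A minimal host-side sketch (illustrative only) of the Q/K/V pointer
// arithmetic used above, assuming the input projection writes its activations
// interleaved per head as [q_head0 | k_head0 | v_head0 | q_head1 | ...], each
// chunk head_dim halves wide; requires <cuda_fp16.h> for half. Names below are
// for illustration and are not part of the extension's API.
struct QkvPtrsSketch {
  half *q;
  half *k;
  half *v;
};
inline QkvPtrsSketch qkv_pointers_sketch(half *input_lin_base, int head_dim) {
  // Q starts at the base, K one head_dim further, V two head_dims further;
  // under this layout the stride from one head's Q to the next head's Q
  // would be 3 * head_dim.
  return {input_lin_base, input_lin_base + head_dim,
          input_lin_base + 2 * head_dim};
}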
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
......@@ -273,13 +248,17 @@ std::vector<torch::Tensor> bwd_cuda(
at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());
auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim;
auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
auto k_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;
auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim;
auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
......@@ -341,8 +320,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// MatMul2 Dgrad1
gemm_switch_fp32accum( state,
a_layout_t,
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
......@@ -364,8 +342,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul2 Dgrad2
gemm_switch_fp32accum( state,
a_layout_n,
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
......@@ -397,17 +374,14 @@ std::vector<torch::Tensor> bwd_cuda(
// Softmax Grad
bool softmax_success = false;
softmax_success = dispatch_softmax_backward<half, half, float>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
assert(softmax_success);
// Matmul1 Dgrad1
gemm_switch_fp32accum( state,
a_layout_n,
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
......@@ -429,8 +403,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul1 Dgrad2
gemm_switch_fp32accum( state,
a_layout_n,
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
......@@ -514,3 +487,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
......@@ -5,111 +5,86 @@ namespace multihead_attn {
namespace self_norm_add {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
const uint8_t* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& lyr_nrm_results,
torch::Tensor const& lyr_nrm_mean,
torch::Tensor const& lyr_nrm_invvar,
torch::Tensor const& inputs,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
torch::Tensor const& dropout_add_mask,
float dropout_prob
);
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
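// Illustrative use of the helpers above (the functions below validate dtypes
// and dimensions with AT_ASSERTM directly instead):
//   CHECK_INPUT(inputs);  // expands to CHECK_CUDA(inputs); CHECK_CONTIGUOUS(inputs);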
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &pad_mask, float dropout_prob) {
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(
use_time_mask,
is_training,
heads,
inputs,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights,
output_weights,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights,
lyr_nrm_beta_weights, input_weights, output_weights,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr()) : nullptr,
dropout_prob);
}
std::vector<torch::Tensor> bwd(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& lyr_nrm_results,
torch::Tensor const& lyr_nrm_mean,
torch::Tensor const& lyr_nrm_invvar,
torch::Tensor const& inputs,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
torch::Tensor const& dropout_add_mask,
float dropout_prob
)
{
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
......@@ -126,40 +101,42 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported");
AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return bwd_cuda(heads,
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
inputs,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights,
output_weights,
dropout_mask,
dropout_add_mask,
dropout_prob
);
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float,
"Only FLOAT is supported");
AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float,
"Only FLOAT is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_results, lyr_nrm_results,
lyr_nrm_mean, lyr_nrm_invvar, inputs, lyr_nrm_gamma_weights,
lyr_nrm_beta_weights, input_weights, output_weights,
dropout_mask, dropout_add_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
......@@ -170,4 +147,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_norm_add::rocblas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def("backward", &multihead_attn::self_norm_add::rocblas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward.");
}
#include <vector>
#include <math.h>
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h"
#include "layer_norm.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
#include "softmax.h"
#include "strided_batched_gemm.h"
namespace multihead_attn {
namespace self_norm_add {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
const uint8_t* pad_mask,
float dropout_prob
)
{
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
......@@ -58,7 +50,8 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = inputs.options().requires_grad(false);
auto lyr_nrm_options = act_options.dtype(torch::kFloat32);
auto mask_options = act_options.dtype(torch::kUInt8);
......@@ -67,22 +60,29 @@ std::vector<torch::Tensor> fwd_cuda(
torch::Tensor lyr_nrm_invvar = torch::empty({batches}, lyr_nrm_options);
torch::Tensor lyr_nrm_results = torch::empty_like(inputs, act_options);
torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor output_lin_results= torch::empty_like(inputs, act_options);
torch::Tensor input_lin_results =
torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor output_lin_results = torch::empty_like(inputs, act_options);
torch::Tensor dropout_add_mask = torch::empty_like(inputs, mask_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options);
// Input Linear Results Pointers to Q, K, and V of interleaved activations
void* q_lin_results_ptr = static_cast<void*>(input_lin_results.data_ptr());
void* k_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);
void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
void *k_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
......@@ -90,16 +90,15 @@ std::vector<torch::Tensor> fwd_cuda(
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm
HostApplyLayerNorm<at::Half,float>(
static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
static_cast<float*>(lyr_nrm_mean.data_ptr()),
static_cast<float*>(lyr_nrm_invvar.data_ptr()),
static_cast<const at::Half*>(inputs.data_ptr()),
HostApplyLayerNorm<at::Half, float>(
static_cast<at::Half *>(lyr_nrm_results.data_ptr()),
static_cast<float *>(lyr_nrm_mean.data_ptr()),
static_cast<float *>(lyr_nrm_invvar.data_ptr()),
static_cast<const at::Half *>(inputs.data_ptr()),
static_cast<int>(batches), // n1
static_cast<int>(embed_dim), // n2
1.0e-5,
static_cast<const at::Half*>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
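// Scalar reference for the fused layer norm above (illustrative only, assuming
// the standard definition over n1 = batches rows of n2 = embed_dim elements,
// eps = 1.0e-5):
//   mean_i   = (1/n2) * sum_j x[i][j]
//   invvar_i = 1 / sqrt((1/n2) * sum_j (x[i][j] - mean_i)^2 + eps)
//   y[i][j]  = (x[i][j] - mean_i) * invvar_i * gamma[j] + beta[j]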
// Input Linear Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
......@@ -129,8 +128,7 @@ std::vector<torch::Tensor> fwd_cuda(
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state,
a_layout_t,
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
......@@ -155,46 +153,35 @@ std::vector<torch::Tensor> fwd_cuda(
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
} else {
if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
q_seq_len);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
}
assert(softmax_success);
if (is_training) {
apex_fused_dropout_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(softmax_results.data_ptr()),
static_cast<at::Half*>(dropout_results.data_ptr()),
static_cast<uint8_t*>(dropout_mask.data_ptr()),
dropout_elems,
apex_fused_dropout_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(softmax_results.data_ptr()),
static_cast<at::Half *>(dropout_results.data_ptr()),
static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
(1.0f - dropout_prob));
}
// Matmul2
gemm_switch_fp32accum( state,
a_layout_n,
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
......@@ -245,57 +232,38 @@ std::vector<torch::Tensor> fwd_cuda(
// End-of-block Dropout-Add
if (is_training) {
apex_dropout_add_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_lin_results.data_ptr()),
static_cast<at::Half const*>(inputs.data_ptr()),
static_cast<at::Half*>(outputs.data_ptr()),
static_cast<uint8_t*>(dropout_add_mask.data_ptr()),
total_tokens,
apex_dropout_add_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(output_lin_results.data_ptr()),
static_cast<at::Half const *>(inputs.data_ptr()),
static_cast<at::Half *>(outputs.data_ptr()),
static_cast<uint8_t *>(dropout_add_mask.data_ptr()), total_tokens,
(1.0f - dropout_prob));
} else {
apex_add_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_lin_results.data_ptr()),
static_cast<at::Half const*>(inputs.data_ptr()),
static_cast<at::Half*>(outputs.data_ptr()),
total_tokens);
apex_add_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(output_lin_results.data_ptr()),
static_cast<at::Half const *>(inputs.data_ptr()),
static_cast<at::Half *>(outputs.data_ptr()), total_tokens);
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
input_lin_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
dropout_add_mask,
outputs
};
return {lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, input_lin_results,
softmax_results, dropout_results, dropout_mask, matmul2_results,
dropout_add_mask, outputs};
}
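// Scalar sketch of the end-of-block dropout-add above (illustrative only,
// assuming the fused kernels implement standard inverted dropout followed by a
// residual add; p_keep is the (1.0f - dropout_prob) value passed in):
//   training:  out = keep_mask * x / p_keep + residual  => E[out] = x + residual
//   inference: out = x + residual
inline float dropout_add_reference_sketch(float x, float residual,
                                          bool keep_mask, float p_keep,
                                          bool is_training) {
  if (!is_training) {
    return x + residual;
  }
  return (keep_mask ? x / p_keep : 0.0f) + residual;
}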
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& lyr_nrm_results,
torch::Tensor const& lyr_nrm_mean,
torch::Tensor const& lyr_nrm_invvar,
torch::Tensor const& inputs,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
torch::Tensor const& dropout_add_mask,
float dropout_prob
)
{
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
......@@ -331,13 +299,17 @@ std::vector<torch::Tensor> bwd_cuda(
torch::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
torch::Tensor input_lin_grads = torch::empty_like(inputs);
auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());
auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim;
auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
auto k_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;
auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim;
auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
......@@ -347,11 +319,10 @@ std::vector<torch::Tensor> bwd_cuda(
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Dropout Add Backward
apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_grads.data_ptr()),
static_cast<at::Half*>(dropout_add_grads.data_ptr()),
static_cast<uint8_t const*>(dropout_add_mask.data_ptr()),
total_tokens,
apex_masked_scale_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(output_grads.data_ptr()),
static_cast<at::Half *>(dropout_add_grads.data_ptr()),
static_cast<uint8_t const *>(dropout_add_mask.data_ptr()), total_tokens,
(1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad
......@@ -407,8 +378,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// MatMul2 Dgrad1
gemm_switch_fp32accum( state,
a_layout_t,
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
......@@ -430,8 +400,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul2 Dgrad2
gemm_switch_fp32accum( state,
a_layout_n,
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
......@@ -463,17 +432,14 @@ std::vector<torch::Tensor> bwd_cuda(
// Softmax Grad
bool softmax_success = false;
softmax_success = dispatch_softmax_backward<half, half, float>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
assert(softmax_success);
// Matmul1 Dgrad1
gemm_switch_fp32accum( state,
a_layout_n,
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
......@@ -495,8 +461,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul1 Dgrad2
gemm_switch_fp32accum( state,
a_layout_n,
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
......@@ -572,33 +537,26 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half,float>(
static_cast<const half*>(input_lin_grads.data_ptr()),
static_cast<half const*>(output_grads.data_ptr()),
static_cast<const float*>(lyr_nrm_mean.data_ptr()),
static_cast<const float*>(lyr_nrm_invvar.data_ptr()),
inputs,
HostLayerNormGradient<half, float>(
static_cast<const half *>(input_lin_grads.data_ptr()),
static_cast<half const *>(output_grads.data_ptr()),
static_cast<const float *>(lyr_nrm_mean.data_ptr()),
static_cast<const float *>(lyr_nrm_invvar.data_ptr()), inputs,
static_cast<int>(batches), // n1
static_cast<int>(embed_dim), // n2
static_cast<const half*>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const half*>(lyr_nrm_beta_weights.data_ptr()),
1.0e-5,
static_cast<half*>(input_grads.data_ptr()),
static_cast<half*>(lyr_nrm_gamma_grads.data_ptr()),
static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
);
static_cast<const half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const half *>(lyr_nrm_beta_weights.data_ptr()), 1.0e-5,
static_cast<half *>(input_grads.data_ptr()),
static_cast<half *>(lyr_nrm_gamma_grads.data_ptr()),
static_cast<half *>(lyr_nrm_beta_grads.data_ptr()));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_grads,
lyr_nrm_gamma_grads,
lyr_nrm_beta_grads,
input_weight_grads,
output_weight_grads
};
return {input_grads, lyr_nrm_gamma_grads, lyr_nrm_beta_grads,
input_weight_grads, output_weight_grads};
}
} // end namespace rocblas_gemmex
} // end namespace self_norm_add
} // end namespace multihead_attn
#pragma once
#include "philox.h"
#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <curand_kernel.h>
#include "philox.h"
#include <assert.h>
#include <cfloat>
#include <cmath>
#include <cuda_fp16.h>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
......@@ -16,59 +18,81 @@
#else
#define APEX_WARP_SHFL_XOR __shfl_xor_sync
#endif
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<__half, 1>(__half *dst,
const __half *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<__half, 4>(__half *dst,
const __half *src) {
*((float2 *)dst) = *((float2 *)src);
}
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst,
const uint8_t *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst,
const uint8_t *src) {
*((half2 *)dst) = *((half2 *)src);
}
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_mask(Datatype *dst, Datatype value, const uint8_t *src);
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_mask(Datatype *dst, Datatype value,
const uint8_t *src);
template <>
__device__ __inline__ void apply_mask<__half, 1>(__half *dst, __half value, const uint8_t *src) {
if (*src == 1) { *dst = value; }
template <>
__device__ __inline__ void apply_mask<__half, 1>(__half *dst, __half value,
const uint8_t *src) {
if (*src == 1) {
*dst = value;
}
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_additive_mask(Datatype *dst, const Datatype *additive_mask);
template <>
__device__ __inline__ void apply_additive_mask<__half, 1>(__half *dst, const __half *additive_mask) {
}
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_additive_mask(Datatype *dst,
const Datatype *additive_mask);
template <>
__device__ __inline__ void
apply_additive_mask<__half, 1>(__half *dst, const __half *additive_mask) {
*dst += *additive_mask;
}
template <>
__device__ __inline__ void apply_additive_mask<__half, 4>(__half *dst, const __half *additive_mask) {
}
template <>
__device__ __inline__ void
apply_additive_mask<__half, 4>(__half *dst, const __half *additive_mask) {
*dst += *additive_mask;
*(dst+1) += *(additive_mask+1);
*(dst+2) += *(additive_mask+2);
*(dst+3) += *(additive_mask+3);}
} // namespace anonymous
*(dst + 1) += *(additive_mask + 1);
*(dst + 2) += *(additive_mask + 2);
*(dst + 3) += *(additive_mask + 3);
}
} // namespace
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Warp Softmax forward
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG=1>
__global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batch_size, int stride, int element_count)
{
assert(ELEMENTS_PER_LDG_STG==1);
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void softmax_warp_forward(input_t *dst, const output_t *src,
int batch_size, int stride,
int element_count) {
assert(ELEMENTS_PER_LDG_STG == 1);
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
......@@ -78,7 +102,8 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batc
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
......@@ -86,26 +111,27 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batc
// load data from global memory
input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0;i < WARP_BATCH;++i) {
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
elements_input[i][it + element] = -std::numeric_limits<float>::infinity();
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements_input[i][it + element] =
-std::numeric_limits<float>::infinity();
}
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + i * element_count + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
&elements_input[i][it], src + i * element_count + it * WARP_SIZE);
}
}
}
// convert input_t to acc_t
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = elements_input[i][it];
}
}
......@@ -116,86 +142,90 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batc
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1;it < WARP_ITERATIONS;++it) {
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] =
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH] { 0.0f };
acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
//elements[i][it] = expf(elements[i][it] - max_value[i]);
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
// elements[i][it] = expf(elements[i][it] - max_value[i]);
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
// store result
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
//dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
// dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
output_t out[ELEMENTS_PER_LDG_STG];
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
}
else {
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
}
}
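// Host-side sketch (illustrative only) of the XOR-butterfly reduction used for
// max_value and sum above: at every step each lane combines its value with the
// lane whose index differs by `offset` (lane ^ offset), so after
// log2(WARP_SIZE) steps every lane holds the warp-wide result. The device code
// does this with APEX_WARP_SHFL_XOR instead of an explicit array.
#include <algorithm>
#include <vector>
inline std::vector<float> xor_butterfly_max_sketch(std::vector<float> lanes) {
  const int warp_size = static_cast<int>(lanes.size()); // assumed power of two
  for (int offset = warp_size / 2; offset > 0; offset /= 2) {
    std::vector<float> partner(warp_size);
    for (int lane = 0; lane < warp_size; ++lane)
      partner[lane] = lanes[lane ^ offset]; // value shuffled in from partner lane
    for (int lane = 0; lane < warp_size; ++lane)
      lanes[lane] = std::max(lanes[lane], partner[lane]);
  }
  return lanes; // every entry now equals the maximum over all lanes
}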
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t>
using softmax_forward_func = void(*)(input_t *dst, const output_t *src, int batch_size, int stride, int element_count);
using softmax_forward_func = void (*)(input_t *dst, const output_t *src,
int batch_size, int stride,
int element_count);
template <typename input_t, typename output_t, typename acc_t>
bool warp_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp, softmax_forward_func<input_t, output_t> &kernel) {
bool warp_softmax_kernel(int log2_elements, int &warp_size,
int &batches_per_warp,
softmax_forward_func<input_t, output_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
......@@ -205,37 +235,37 @@ bool warp_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_war
switch (log2_elements) {
case 0: // 1
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,1,1,1>;
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 1, 1>;
break;
case 1: // 2
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,1,2,1>;
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 2, 1>;
break;
case 2: // 4
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,1,4,1>;
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 4, 1>;
break;
case 3: // 8
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,1,8,1>;
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 8, 1>;
break;
case 4: // 16
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,1,16,1>;
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 16, 1>;
break;
case 5: // 32
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,1,32,1>;
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 32, 1>;
break;
case 6: // 64
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,2,32,1>;
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 2, 32, 1>;
break;
case 7: // 128
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 4, 32, 1>;
break;
case 8: // 256
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1,8,32,1>;
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1, 8, 32, 1>;
break;
case 9: // 512
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1,16,32,1>;
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1, 16, 32, 1>;
break;
case 10: // 1024
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1,32,32,1>;
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1, 32, 32, 1>;
break;
default:
return false;
......@@ -243,19 +273,22 @@ bool warp_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_war
return true;
}
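// Worked example (illustrative only) of the dispatch arithmetic above, using
// the relationship visible in the switch: next_power_of_two = 1 << log2_elements,
// warp_size = min(next_power_of_two, 32),
// WARP_ITERATIONS = next_power_of_two / warp_size.
//   softmax_elements = 384 -> log2_elements = 9 (next power of two 512)
//                          -> warp_size = 32, WARP_ITERATIONS = 16, WARP_BATCH = 1 (case 9)
//   softmax_elements =  64 -> log2_elements = 6
//                          -> warp_size = 32, WARP_ITERATIONS = 2,  WARP_BATCH = 2 (case 6)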
template<typename input_t, typename output_t, typename acc_t>
bool dispatch_softmax(output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride, int batch_count)
{
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_softmax(output_t *dst, const input_t *src, int softmax_elements,
int softmax_elements_stride, int batch_count) {
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 1024) {
// compute function index. there's a function for each power of two size up to 1024.
// compute function index. there's a function for each power of two size up
// to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements) ++log2_elements;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
softmax_forward_func<input_t, output_t> kernel;
int warp_size, batches_per_warp;
if (!warp_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {
if (!warp_softmax_kernel<input_t, output_t, acc_t>(
log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
......@@ -271,29 +304,35 @@ bool dispatch_softmax(output_t *dst, const input_t *src, int softmax_elements, i
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, batch_count, softmax_elements_stride, softmax_elements);
return true;
}
return false;
}
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE, int ELEMENTS_PER_LDG_STG>
__global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride, at::PhiloxCudaState philox_args, float p)
{
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
int WARP_ITERATIONS, int WARP_SIZE, int ELEMENTS_PER_LDG_STG>
__global__ void additive_masked_softmax_dropout_warp_forward_vec4(
output_t *dst, uint8_t *dropout_mask, const input_t *src,
const input_t *pad_mask, int batch_size, int stride, int element_count,
int pad_batch_stride, at::PhiloxCudaState philox_args, float p) {
assert(ELEMENTS_PER_LDG_STG==4);
assert(ELEMENTS_PER_LDG_STG == 4);
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;
acc_t pinv = acc_t(1)/p;
int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x +
threadIdx.x;
acc_t pinv = acc_t(1) / p;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
//vectorize if element_count is a multiple of 4, else don't vectorize
// vectorize if element_count is a multiple of 4, else don't vectorize
input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
......@@ -302,34 +341,38 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst,
dropout_mask += thread_offset;
// load data from global memory
for (int i = 0;i < WARP_BATCH;++i) {
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;
const half* curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride +
ELEMENTS_PER_LDG_STG * local_idx;
const half *curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
//masking_value is a large negative value
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
// masking_value is a large negative value
elements_input[i][it + element] = -10000;
}
if (element_index < batch_element_count) {
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);
apply_additive_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], curr_mask + itr_jmp); //(__half)-std::numeric_limits<float>::infinity()
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
src + itr_idx);
apply_additive_mask<input_t, ELEMENTS_PER_LDG_STG>(
&elements_input[i][it],
curr_mask +
itr_jmp); //(__half)-std::numeric_limits<float>::infinity()
}
}
}
// convert input_t to acc_t
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;++it) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = elements_input[i][it];
}
}
......@@ -340,48 +383,49 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst,
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1;it < WARP_ITERATIONS;++it) {
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] =
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH] { 0.0f };
acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
......@@ -389,65 +433,71 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst,
Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds));
uint8_t rands[WARP_BATCH][WARP_ITERATIONS];
float4 rand_num;
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it+=ELEMENTS_PER_LDG_STG) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
rand_num = uniform4(ph());
rands[i][it] = (rand_num.x <= p) > 0.5;
rands[i][it+1] = (rand_num.y <= p) > 0.5;
rands[i][it+2] = (rand_num.z <= p) > 0.5;
rands[i][it+3] = (rand_num.w <= p) > 0.5;
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(dropout_mask + i * element_count + it * WARP_SIZE, &rands[i][it]);
rands[i][it + 1] = (rand_num.y <= p) > 0.5;
rands[i][it + 2] = (rand_num.z <= p) > 0.5;
rands[i][it + 3] = (rand_num.w <= p) > 0.5;
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(
dropout_mask + i * element_count + it * WARP_SIZE, &rands[i][it]);
}
}
}
// store result
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
out[element] = rands[i][it+element] * (pinv * (elements[i][it + element] / sum[i]));
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = rands[i][it + element] *
(pinv * (elements[i][it + element] / sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
dst + i * element_count + it * WARP_SIZE, out);
}
else {
} else {
break;
}
}
}
}
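// Scalar sketch (illustrative only) of the fused softmax-dropout output stored
// above: p is the keep probability, pinv = 1/p, and each softmax element is
// kept with probability p and rescaled by pinv so its expectation is unchanged,
//   E[out] = p * (1/p) * softmax + (1 - p) * 0 = softmax.
inline float softmax_dropout_reference_sketch(float softmax_value, float rand01,
                                              float p /* keep probability */) {
  const bool keep = (rand01 <= p); // mirrors rands[i][it] = (rand_num.x <= p)
  return keep ? (1.0f / p) * softmax_value : 0.0f;
}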
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE, int ELEMENTS_PER_LDG_STG>
__global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride, at::PhiloxCudaState philox_args, float p)
{
assert(ELEMENTS_PER_LDG_STG==1);
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
int WARP_ITERATIONS, int WARP_SIZE, int ELEMENTS_PER_LDG_STG>
__global__ void additive_masked_softmax_dropout_warp_forward(
output_t *dst, uint8_t *dropout_mask, const input_t *src,
const input_t *pad_mask, int batch_size, int stride, int element_count,
int pad_batch_stride, at::PhiloxCudaState philox_args, float p) {
assert(ELEMENTS_PER_LDG_STG == 1);
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;
acc_t pinv = acc_t(1)/p;
int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x +
threadIdx.x;
acc_t pinv = acc_t(1) / p;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
//vectorize if element_count is multiple of 4, else don't vectorize
//vectorize if element_count is a multiple of 4, else don't vectorize
// vectorize if element_count is a multiple of 4, else don't vectorize
int thread_offset = first_batch * stride + local_idx;
......@@ -456,16 +506,17 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
dropout_mask += thread_offset;
// load data from global memory
for (int i = 0;i < WARP_BATCH;++i) {
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + local_idx;
const half* curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += 1) {
int pad_thread_offset =
((first_batch + i) / pad_batch_stride) * stride + local_idx;
const half *curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += 1) {
int element_index = local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0;element < 1;++element) {
//masking_value is a large negative value
#pragma unroll
for (int element = 0; element < 1; ++element) {
// masking_value is a large negative value
elements_input[i][it + element] = -10000;
}
......@@ -473,17 +524,17 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
copy_vector<input_t, 1>(&elements_input[i][it], src + itr_idx);
apply_additive_mask<input_t, 1>(&elements_input[i][it], curr_mask + itr_jmp);
apply_additive_mask<input_t, 1>(&elements_input[i][it],
curr_mask + itr_jmp);
}
}
}
// convert input_t to acc_t
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;++it) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = elements_input[i][it];
}
}
......@@ -494,86 +545,85 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1;it < WARP_ITERATIONS;++it) {
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] =
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH] { 0.0f };
acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
curandStatePhilox4_32_10_t state;
auto seeds = at::cuda::philox::unpack(philox_args);
curand_init(
std::get<0>(seeds),
tid,
std::get<1>(seeds),
&state);
// store result
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += 1) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += 1) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
output_t out[1];
acc_t softmax_out[1];
uint8_t dropout_mask_temp[1];
//generate one random number per element (scalar, non-vectorized path)
// generate one random number per element (scalar, non-vectorized path)
float rand = curand_uniform(&state);
float *rand_ptr = (float*)(&rand);
#pragma unroll
for (int element = 0;element < 1;++element) {
float *rand_ptr = (float *)(&rand);
#pragma unroll
for (int element = 0; element < 1; ++element) {
softmax_out[element] = (elements[i][it + element] / sum[i]);
rand_ptr[element] = rand_ptr[element] <= p;
out[element] = rand_ptr[element] * pinv * softmax_out[element];
dropout_mask_temp[element] = rand_ptr[element] > 0.5; // just to distinguish 0.0f and 1.0f
dropout_mask_temp[element] =
rand_ptr[element] > 0.5; // just to distinguish 0.0f and 1.0f
}
copy_vector<output_t, 1>(dst + i * element_count + it * WARP_SIZE, out);
copy_vector<uint8_t, 1>(dropout_mask + i * element_count + it * WARP_SIZE, dropout_mask_temp);
copy_vector<uint8_t, 1>(dropout_mask + i * element_count +
it * WARP_SIZE,
dropout_mask_temp);
}
else {
} else {
break;
}
}
......@@ -581,15 +631,20 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
}
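// The kernel above fuses a numerically stable row softmax with inverted
// dropout: each kept element is scaled by pinv = 1/p so the expected output
// matches the no-dropout forward pass. A minimal host-side sketch of the same
// math for one row (illustrative only; softmax_dropout_row_reference and
// keep_mask are assumptions, not part of this file; relies on the <cmath> and
// <cstdint> facilities already used above):
inline void softmax_dropout_row_reference(float *row, const uint8_t *keep_mask,
                                          int n, float keep_prob) {
  float max_val = row[0];
  for (int i = 1; i < n; ++i) // row max, subtracted for numerical stability
    max_val = row[i] > max_val ? row[i] : max_val;
  float sum = 0.f;
  for (int i = 0; i < n; ++i) {
    row[i] = std::exp(row[i] - max_val);
    sum += row[i];
  }
  const float pinv = 1.f / keep_prob;
  for (int i = 0; i < n; ++i) // keep_mask[i] == 1 keeps the element
    row[i] = keep_mask[i] ? (row[i] / sum) * pinv : 0.f;
}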
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t, typename acc_t>
using additive_masked_softmax_dropout_forward_func = void(*)(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride, at::PhiloxCudaState philox_args, float p);
using additive_masked_softmax_dropout_forward_func = void (*)(
output_t *dst, uint8_t *dropout_mask, const input_t *src,
const input_t *pad_mask, int batch_size, int stride, int element_count,
int pad_batch_stride, at::PhiloxCudaState philox_args, float p);
template <typename input_t, typename output_t, typename acc_t>
bool warp_additive_masked_softmax_dropout_kernel(int element_count, int log2_elements, int &warp_size, int &batches_per_warp, additive_masked_softmax_dropout_forward_func<input_t, output_t, acc_t> &kernel) {
bool warp_additive_masked_softmax_dropout_kernel(
int element_count, int log2_elements, int &warp_size, int &batches_per_warp,
additive_masked_softmax_dropout_forward_func<input_t, output_t, acc_t>
&kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
......@@ -599,45 +654,77 @@ bool warp_additive_masked_softmax_dropout_kernel(int element_count, int log2_ele
bool flag_vec4 = (element_count % 4 == 0);
switch (log2_elements) {
case 0: // 1
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,1,1>;
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 2, 1, 1, 1>;
break;
case 1: // 2
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,2,1>;
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 2, 1, 2, 1>;
break;
case 2: // 4
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,4,1>;
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 2, 1, 4, 1>;
break;
case 3: // 8
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,8,1>;
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 2, 1, 8, 1>;
break;
case 4: // 16
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,16,1>;
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 2, 1, 16, 1>;
break;
case 5: // 32
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,1,32,1>;
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 2, 1, 32, 1>;
break;
case 6: // 64
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,2,32,1>;
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 2, 2, 32, 1>;
break;
case 7: // 128
if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 2,4,32,4>;
else kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;
if (flag_vec4)
kernel = &additive_masked_softmax_dropout_warp_forward_vec4<
input_t, output_t, acc_t, 2, 4, 32, 4>;
else
kernel =
&additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 2, 4, 32, 1>;
break;
case 8: // 256
if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,8,32,4>;
else kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,8,32,1>;
if (flag_vec4)
kernel = &additive_masked_softmax_dropout_warp_forward_vec4<
input_t, output_t, acc_t, 1, 8, 32, 4>;
else
kernel =
&additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 1, 8, 32, 1>;
break;
case 9: // 512
if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,16,32,4>;
else kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,16,32,1>;
if (flag_vec4)
kernel = &additive_masked_softmax_dropout_warp_forward_vec4<
input_t, output_t, acc_t, 1, 16, 32, 4>;
else
kernel =
&additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 1, 16, 32, 1>;
break;
case 10: // 1024
if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,32,32,4>;
else kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,32,32,1>;
if (flag_vec4)
kernel = &additive_masked_softmax_dropout_warp_forward_vec4<
input_t, output_t, acc_t, 1, 32, 32, 4>;
else
kernel =
&additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 1, 32, 32, 1>;
break;
case 11: // 2048
if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,64,32,4>;
else kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1,64,32,1>;
if (flag_vec4)
kernel = &additive_masked_softmax_dropout_warp_forward_vec4<
input_t, output_t, acc_t, 1, 64, 32, 4>;
else
kernel =
&additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 1, 64, 32, 1>;
break;
default:
return false;
......@@ -645,22 +732,29 @@ bool warp_additive_masked_softmax_dropout_kernel(int element_count, int log2_ele
return true;
}
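// Worked example of the mapping above (illustrative, not used by the code):
// softmax_elements = 384 rounds up to next_power_of_two = 512, i.e.
// log2_elements = 9, which selects WARP_BATCH = 1, WARP_ITERATIONS = 16,
// WARP_SIZE = 32 (the vec4 variant is chosen because 384 is a multiple of 4).
// Each of the 32 lanes then covers 16 strided elements of the padded
// 512-wide row; positions beyond element_count keep the masking value and
// effectively vanish from the max/sum reductions.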
template<typename input_t, typename output_t, typename acc_t>
bool dispatch_additive_masked_softmax_dropout(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int totalElements, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride, float p, cudaStream_t streamid)// p is the probability to keep, not drop
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_additive_masked_softmax_dropout(
output_t *dst, uint8_t *dropout_mask, const input_t *src,
const input_t *pad_mask, int totalElements, int softmax_elements,
int softmax_elements_stride, int batch_count, int pad_batch_stride, float p,
cudaStream_t streamid) // p is the probability to keep, not drop
{
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 2048) {
// compute function index. there's a function for each power of two size up to 2048.
// compute function index. there's a function for each power of two size up
// to 2048.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements) ++log2_elements;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
additive_masked_softmax_dropout_forward_func<input_t, output_t, acc_t> kernel;
additive_masked_softmax_dropout_forward_func<input_t, output_t, acc_t>
kernel;
int warp_size, batches_per_warp;
if (!warp_additive_masked_softmax_dropout_kernel<input_t, output_t, acc_t>(softmax_elements, log2_elements, warp_size, batches_per_warp, kernel)) {
if (!warp_additive_masked_softmax_dropout_kernel<input_t, output_t, acc_t>(
softmax_elements, log2_elements, warp_size, batches_per_warp,
kernel)) {
return false;
}
......@@ -671,8 +765,9 @@ bool dispatch_additive_masked_softmax_dropout(output_t *dst, uint8_t *dropout_ma
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
c10::optional<at::Generator> gen_;
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
int64_t counter_offset = (totalElements/(blocks*threads_per_block)+1);
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
int64_t counter_offset = (totalElements / (blocks * threads_per_block) + 1);
at::PhiloxCudaState rng_engine_inputs;
{
std::lock_guard<std::mutex> lock(gen->mutex_);
......@@ -683,20 +778,24 @@ bool dispatch_additive_masked_softmax_dropout(output_t *dst, uint8_t *dropout_ma
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, streamid>>>(dst, dropout_mask, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride, rng_engine_inputs, p);
kernel<<<blocks, threads, 0, streamid>>>(
dst, dropout_mask, src, pad_mask, batch_count, softmax_elements_stride,
softmax_elements, pad_batch_stride, rng_engine_inputs, p);
return true;
}
return false;
}
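// Launch-geometry sketch for the dispatcher above (illustrative numbers):
// with warp_size = 32 and batches_per_warp = 1, warps_per_block = 128 / 32 = 4
// and batches_per_block = 4, so batch_count = 1536 rows launch
// (1536 + 4 - 1) / 4 = 384 blocks of dim3(32, 4, 1) threads. counter_offset
// estimates how many curand_uniform() draws a single thread will make, so the
// generator can reserve a disjoint Philox counter range for each launch.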
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG=1>
__global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride)
{
assert(ELEMENTS_PER_LDG_STG==1);
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void additive_masked_softmax_warp_forward(
input_t *dst, const output_t *src, const input_t *pad_mask, int batch_size,
int stride, int element_count, int pad_batch_stride) {
assert(ELEMENTS_PER_LDG_STG == 1);
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
......@@ -706,7 +805,8 @@ __global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
......@@ -715,35 +815,36 @@ __global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_
// load data from global memory
input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0;i < WARP_BATCH;++i) {
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;
const half* curr_mask = pad_mask + pad_thread_offset;
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride +
ELEMENTS_PER_LDG_STG * local_idx;
const half *curr_mask = pad_mask + pad_thread_offset;
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
//masking_value is a large negative value
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
// masking_value is a large negative value
elements_input[i][it + element] = -10000;
}
if (element_index < batch_element_count) {
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);
//apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
src + itr_idx);
// apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
// (__half)-std::numeric_limits<float>::infinity(),
// curr_mask + itr_jmp);
elements_input[i][it] += *(curr_mask + itr_jmp);
}
}
}
// convert input_t to acc_t
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = elements_input[i][it];
}
}
......@@ -754,70 +855,71 @@ __global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1;it < WARP_ITERATIONS;++it) {
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] =
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH] { 0.0f };
acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
//elements[i][it] = expf(elements[i][it] - max_value[i]);
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
// elements[i][it] = expf(elements[i][it] - max_value[i]);
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
// store result
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
//dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
// dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
output_t out[ELEMENTS_PER_LDG_STG];
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
}
else {
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
......@@ -825,14 +927,18 @@ __global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_
}
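// Unlike the byte-masked variants below, this kernel consumes a half-precision
// additive mask: the mask value (typically 0 or a large negative number) is
// simply added to the logits before the max/exp/sum reductions, so masked
// positions end up with a vanishing softmax weight instead of an explicit fill.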
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t>
using additive_masked_softmax_forward_func = void(*)(input_t *dst, const output_t *src, const half *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride);
using additive_masked_softmax_forward_func = void (*)(
input_t *dst, const output_t *src, const half *pad_mask, int batch_size,
int stride, int element_count, int pad_batch_stride);
template <typename input_t, typename output_t, typename acc_t>
bool warp_additive_masked_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp, additive_masked_softmax_forward_func<input_t, output_t> &kernel) {
bool warp_additive_masked_softmax_kernel(
int log2_elements, int &warp_size, int &batches_per_warp,
additive_masked_softmax_forward_func<input_t, output_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
......@@ -842,37 +948,48 @@ bool warp_additive_masked_softmax_kernel(int log2_elements, int &warp_size, int
switch (log2_elements) {
case 0: // 1
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,1,1>;
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,
1, 1, 1>;
break;
case 1: // 2
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,2,1>;
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,
1, 2, 1>;
break;
case 2: // 4
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,4,1>;
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,
1, 4, 1>;
break;
case 3: // 8
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,8,1>;
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,
1, 8, 1>;
break;
case 4: // 16
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,16,1>;
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,
1, 16, 1>;
break;
case 5: // 32
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,32,1>;
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,
1, 32, 1>;
break;
case 6: // 64
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,2,32,1>;
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,
2, 32, 1>;
break;
case 7: // 128
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,
4, 32, 1>;
break;
case 8: // 256
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,8,32,1>;
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,
8, 32, 1>;
break;
case 9: // 512
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,16,32,1>;
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,
16, 32, 1>;
break;
case 10: // 1024
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,32,32,1>;
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,
32, 32, 1>;
break;
default:
return false;
......@@ -880,19 +997,25 @@ bool warp_additive_masked_softmax_kernel(int log2_elements, int &warp_size, int
return true;
}
template<typename input_t, typename output_t, typename acc_t>
bool dispatch_additive_masked_softmax(output_t *dst, const input_t *src, const input_t *pad_mask, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride)
{
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_additive_masked_softmax(output_t *dst, const input_t *src,
const input_t *pad_mask,
int softmax_elements,
int softmax_elements_stride,
int batch_count, int pad_batch_stride) {
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 1024) {
// compute function index. there's a function for each power of two size up to 1024.
// compute function index. there's a function for each power of two size up
// to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements) ++log2_elements;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
additive_masked_softmax_forward_func<input_t, output_t> kernel;
int warp_size, batches_per_warp;
if (!warp_additive_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {
if (!warp_additive_masked_softmax_kernel<input_t, output_t, acc_t>(
log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
......@@ -908,24 +1031,31 @@ bool dispatch_additive_masked_softmax(output_t *dst, const input_t *src, const i
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, pad_mask, batch_count, softmax_elements_stride,
softmax_elements, pad_batch_stride);
return true;
}
return false;
}
template<typename input_t, typename output_t, typename acc_t>
bool dispatch_additive_masked_softmax_stream(output_t *dst, const input_t *src, const input_t *pad_mask, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride, cudaStream_t streamid)
{
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_additive_masked_softmax_stream(
output_t *dst, const input_t *src, const input_t *pad_mask,
int softmax_elements, int softmax_elements_stride, int batch_count,
int pad_batch_stride, cudaStream_t streamid) {
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 1024) {
// compute function index. there's a function for each power of two size up to 1024.
// compute function index. there's a function for each power of two size up
// to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements) ++log2_elements;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
additive_masked_softmax_forward_func<input_t, output_t> kernel;
int warp_size, batches_per_warp;
if (!warp_additive_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {
if (!warp_additive_masked_softmax_kernel<input_t, output_t, acc_t>(
log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
// use 128 threads per block to maximize gpu utilization
......@@ -937,23 +1067,25 @@ bool dispatch_additive_masked_softmax_stream(output_t *dst, const input_t *src,
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, streamid>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);
kernel<<<blocks, threads, 0, streamid>>>(
dst, src, pad_mask, batch_count, softmax_elements_stride,
softmax_elements, pad_batch_stride);
return true;
}
return false;
}
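// The only difference from dispatch_additive_masked_softmax above is that the
// caller supplies the cudaStream_t instead of the current ATen stream being
// looked up internally. A hedged usage sketch (the pointer and size names are
// assumptions, not part of this file):
//
// dispatch_additive_masked_softmax_stream<half, half, float>(
//     static_cast<half *>(dst_ptr), static_cast<const half *>(src_ptr),
//     static_cast<const half *>(additive_mask_ptr),
//     k_seq_len,                 // softmax_elements
//     k_seq_len,                 // softmax_elements_stride
//     attn_batches * q_seq_len,  // batch_count
//     heads * q_seq_len,         // rows sharing one mask row (pad_batch_stride)
//     side_stream);              // any stream the caller has ordered correctly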
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG=1>
__global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, const uint8_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride)
{
assert(ELEMENTS_PER_LDG_STG==1);
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void
masked_softmax_warp_forward(input_t *dst, const output_t *src,
const uint8_t *pad_mask, int batch_size, int stride,
int element_count, int pad_batch_stride) {
assert(ELEMENTS_PER_LDG_STG == 1);
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
......@@ -963,7 +1095,8 @@ __global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, c
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
......@@ -972,33 +1105,36 @@ __global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, c
// load data from global memory
input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0;i < WARP_BATCH;++i) {
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;
const uint8_t* curr_mask = pad_mask + pad_thread_offset;
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride +
ELEMENTS_PER_LDG_STG * local_idx;
const uint8_t *curr_mask = pad_mask + pad_thread_offset;
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
elements_input[i][it + element] = -std::numeric_limits<float>::infinity();
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements_input[i][it + element] =
-std::numeric_limits<float>::infinity();
}
if (element_index < batch_element_count) {
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);
apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
src + itr_idx);
apply_mask<input_t, ELEMENTS_PER_LDG_STG>(
&elements_input[i][it],
(__half)-std::numeric_limits<float>::infinity(),
curr_mask + itr_jmp);
}
}
}
// convert input_t to acc_t
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = elements_input[i][it];
}
}
......@@ -1009,70 +1145,71 @@ __global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, c
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1;it < WARP_ITERATIONS;++it) {
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] =
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH] { 0.0f };
acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
//elements[i][it] = expf(elements[i][it] - max_value[i]);
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
// elements[i][it] = expf(elements[i][it] - max_value[i]);
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
// store result
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
//dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
// dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
output_t out[ELEMENTS_PER_LDG_STG];
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
}
else {
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
......@@ -1080,14 +1217,20 @@ __global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, c
}
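// This variant takes a uint8_t pad mask and relies on apply_mask to overwrite
// padded positions with the fill value (-inf here) before the reductions. A
// plausible scalar definition consistent with the call sites and with the
// backward masking logic further below (an assumption for illustration; the
// real template is defined elsewhere in this file set and may be vectorized):
//
// template <typename Datatype, int ELEMENTS_PER_LDG>
// __device__ __inline__ void apply_mask(Datatype *dst, Datatype value,
//                                       const uint8_t *src) {
//   if (*src == 1)   // a set mask byte marks a padded key position
//     *dst = value;  // e.g. -inf, so exp() contributes nothing to the row sum
// }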
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t>
using masked_softmax_forward_func = void(*)(input_t *dst, const output_t *src, const uint8_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride);
using masked_softmax_forward_func = void (*)(input_t *dst, const output_t *src,
const uint8_t *pad_mask,
int batch_size, int stride,
int element_count,
int pad_batch_stride);
template <typename input_t, typename output_t, typename acc_t>
bool warp_masked_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp, masked_softmax_forward_func<input_t, output_t> &kernel) {
bool warp_masked_softmax_kernel(
int log2_elements, int &warp_size, int &batches_per_warp,
masked_softmax_forward_func<input_t, output_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
......@@ -1097,37 +1240,44 @@ bool warp_masked_softmax_kernel(int log2_elements, int &warp_size, int &batches_
switch (log2_elements) {
case 0: // 1
kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,1,1>;
kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 1, 1>;
break;
case 1: // 2
kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,2,1>;
kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 2, 1>;
break;
case 2: // 4
kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,4,1>;
kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 4, 1>;
break;
case 3: // 8
kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,8,1>;
kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 8, 1>;
break;
case 4: // 16
kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,16,1>;
kernel =
&masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 16, 1>;
break;
case 5: // 32
kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,32,1>;
kernel =
&masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 32, 1>;
break;
case 6: // 64
kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,2,32,1>;
kernel =
&masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 2, 32, 1>;
break;
case 7: // 128
kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;
kernel =
&masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 4, 32, 1>;
break;
case 8: // 256
kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 1,8,32,1>;
kernel =
&masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 8, 32, 1>;
break;
case 9: // 512
kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 1,16,32,1>;
kernel =
&masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 16, 32, 1>;
break;
case 10: // 1024
kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 1,32,32,1>;
kernel =
&masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 32, 32, 1>;
break;
default:
return false;
......@@ -1135,19 +1285,24 @@ bool warp_masked_softmax_kernel(int log2_elements, int &warp_size, int &batches_
return true;
}
template<typename input_t, typename output_t, typename acc_t>
bool dispatch_masked_softmax(output_t *dst, const input_t *src, const uint8_t *pad_mask, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride)
{
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_masked_softmax(output_t *dst, const input_t *src,
const uint8_t *pad_mask, int softmax_elements,
int softmax_elements_stride, int batch_count,
int pad_batch_stride) {
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 1024) {
// compute function index. there's a function for each power of two size up to 1024.
// compute function index. there's a function for each power of two size up
// to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements) ++log2_elements;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
masked_softmax_forward_func<input_t, output_t> kernel;
int warp_size, batches_per_warp;
if (!warp_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {
if (!warp_masked_softmax_kernel<input_t, output_t, acc_t>(
log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
......@@ -1163,20 +1318,24 @@ bool dispatch_masked_softmax(output_t *dst, const input_t *src, const uint8_t *p
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, pad_mask, batch_count, softmax_elements_stride,
softmax_elements, pad_batch_stride);
return true;
}
return false;
}
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG=1>
__global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *src, const uint8_t *pad_mask, int batch_size, int stride, int element_count, int mod_seq_len)
{
assert(ELEMENTS_PER_LDG_STG==1);
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void time_masked_softmax_warp_forward(
input_t *dst, const output_t *src, const uint8_t *pad_mask, int batch_size,
int stride, int element_count, int mod_seq_len) {
assert(ELEMENTS_PER_LDG_STG == 1);
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
......@@ -1186,7 +1345,8 @@ __global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *s
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
......@@ -1195,33 +1355,36 @@ __global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *s
// load data from global memory
input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0;i < WARP_BATCH;++i) {
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int pad_thread_offset = ( (first_batch + i) % mod_seq_len) * stride + ELEMENTS_PER_LDG_STG * local_idx;
const uint8_t* curr_mask = pad_mask + pad_thread_offset;
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
int pad_thread_offset = ((first_batch + i) % mod_seq_len) * stride +
ELEMENTS_PER_LDG_STG * local_idx;
const uint8_t *curr_mask = pad_mask + pad_thread_offset;
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
elements_input[i][it + element] = -std::numeric_limits<float>::infinity();
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements_input[i][it + element] =
-std::numeric_limits<float>::infinity();
}
if (element_index < batch_element_count) {
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);
apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
src + itr_idx);
apply_mask<input_t, ELEMENTS_PER_LDG_STG>(
&elements_input[i][it],
(__half)-std::numeric_limits<float>::infinity(),
curr_mask + itr_jmp);
}
}
}
// convert input_t to acc_t
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = elements_input[i][it];
}
}
......@@ -1232,70 +1395,71 @@ __global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *s
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1;it < WARP_ITERATIONS;++it) {
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] =
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH] { 0.0f };
acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
//elements[i][it] = expf(elements[i][it] - max_value[i]);
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
// elements[i][it] = expf(elements[i][it] - max_value[i]);
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
// store result
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
//dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
// dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
output_t out[ELEMENTS_PER_LDG_STG];
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
}
else {
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
......@@ -1303,14 +1467,18 @@ __global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *s
}
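// The "time" variant differs from masked_softmax_warp_forward only in how the
// pad-mask row is selected: (first_batch + i) % mod_seq_len instead of
// (first_batch + i) / pad_batch_stride, i.e. the mask repeats with the
// sequence period rather than being shared by a fixed block of rows.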
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t>
using time_masked_softmax_forward_func = void(*)(input_t *dst, const output_t *src, const uint8_t *pad_mask, int batch_size, int stride, int element_count, int mod_seq_len);
using time_masked_softmax_forward_func =
void (*)(input_t *dst, const output_t *src, const uint8_t *pad_mask,
int batch_size, int stride, int element_count, int mod_seq_len);
template <typename input_t, typename output_t, typename acc_t>
bool warp_time_masked_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp, time_masked_softmax_forward_func<input_t, output_t> &kernel) {
bool warp_time_masked_softmax_kernel(
int log2_elements, int &warp_size, int &batches_per_warp,
time_masked_softmax_forward_func<input_t, output_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
......@@ -1320,37 +1488,48 @@ bool warp_time_masked_softmax_kernel(int log2_elements, int &warp_size, int &bat
switch (log2_elements) {
case 0: // 1
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,1,1>;
kernel =
&time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 1, 1>;
break;
case 1: // 2
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,2,1>;
kernel =
&time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 2, 1>;
break;
case 2: // 4
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,4,1>;
kernel =
&time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 4, 1>;
break;
case 3: // 8
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,8,1>;
kernel =
&time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 8, 1>;
break;
case 4: // 16
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,16,1>;
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1,
16, 1>;
break;
case 5: // 32
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,32,1>;
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1,
32, 1>;
break;
case 6: // 64
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,2,32,1>;
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 2,
32, 1>;
break;
case 7: // 128
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 4,
32, 1>;
break;
case 8: // 256
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,8,32,1>;
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 8,
32, 1>;
break;
case 9: // 512
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,16,32,1>;
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 16,
32, 1>;
break;
case 10: // 1024
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,32,32,1>;
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 32,
32, 1>;
break;
default:
return false;
......@@ -1358,19 +1537,24 @@ bool warp_time_masked_softmax_kernel(int log2_elements, int &warp_size, int &bat
return true;
}
template<typename input_t, typename output_t, typename acc_t>
bool dispatch_time_masked_softmax(output_t *dst, const input_t *src, const uint8_t *pad_mask, int softmax_elements, int softmax_elements_stride, int batch_count, int mod_seq_len)
{
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_time_masked_softmax(output_t *dst, const input_t *src,
const uint8_t *pad_mask, int softmax_elements,
int softmax_elements_stride, int batch_count,
int mod_seq_len) {
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 1024) {
// compute function index. there's a function for each power of two size up to 1024.
// compute function index. there's a function for each power of two size up
// to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements) ++log2_elements;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
time_masked_softmax_forward_func<input_t, output_t> kernel;
int warp_size, batches_per_warp;
if (!warp_time_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {
if (!warp_time_masked_softmax_kernel<input_t, output_t, acc_t>(
log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
......@@ -1386,21 +1570,23 @@ bool dispatch_time_masked_softmax(output_t *dst, const input_t *src, const uint8
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, mod_seq_len);
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, pad_mask, batch_count, softmax_elements_stride,
softmax_elements, mod_seq_len);
return true;
}
return false;
}
static int log2_ceil_native(int value) {
int log2_ceil_native(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
while ((1 << log2_value) < value)
++log2_value;
return log2_value;
}
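// Example (illustrative): log2_ceil_native(384) == 9, since 1 << 9 == 512 is
// the smallest power of two that is >= 384.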
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000 && !defined(__HIP_PLATFORM_HCC__)
return __shfl_xor_sync(mask, value, laneMask, width);
#else
......@@ -1409,10 +1595,10 @@ __device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int wid
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE>
__device__ __forceinline__ void warp_reduce_sum(acc_t* sum) {
#pragma unroll
__device__ __forceinline__ void warp_reduce_sum(acc_t *sum) {
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = sum[i] + b;
......@@ -1421,17 +1607,24 @@ __device__ __forceinline__ void warp_reduce_sum(acc_t* sum) {
}
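// The XOR shuffle above is a butterfly reduction: after log2(WARP_SIZE) rounds
// every lane holds the full sum, so no extra broadcast is needed before the
// division by sum[i]. Worked example (illustrative) for WARP_SIZE = 4 and lane
// values {1, 2, 3, 4}:
//   offset = 2 pairs lanes (0,2) and (1,3)  -> {4, 6, 4, 6}
//   offset = 1 pairs lanes (0,1) and (2,3)  -> {10, 10, 10, 10}
// The max reductions in the forward kernels follow the same pattern with max()
// in place of addition.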
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Warp softmax backward functions as fused variants of at::softmax_backward_data function
// Warp softmax backward functions as fused variants of
// at::softmax_backward_data function
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//softmax backward data function is taken from native pytorch, elementwise mul is fused in the epilogue, as well as masking and scaling for fusing dropout
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>
__global__ void masked_scale_softmax_warp_backward_masked_dgrad(output_t *gradInput, const input_t *grad, const input_t *output, const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, int batch_size, int stride, int element_count, int heads)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel.
// softmax backward data function is taken from native pytorch, elementwise mul
// is fused in the epilogue, as well as masking and scaling for fusing dropout
template <typename input_t, typename output_t, typename acc_t,
int log2_elements, bool is_log_softmax>
__global__ void masked_scale_softmax_warp_backward_masked_dgrad(
output_t *gradInput, const input_t *grad, const input_t *output,
const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, int batch_size,
int stride, int element_count, int heads) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_SIZE =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
......@@ -1443,7 +1636,8 @@ __global__ void masked_scale_softmax_warp_backward_masked_dgrad(output_t *gradIn
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x % WARP_SIZE;
// the first element to process by the current thread
......@@ -1453,21 +1647,25 @@ __global__ void masked_scale_softmax_warp_backward_masked_dgrad(output_t *gradIn
gradInput += thread_offset;
mask += thread_offset;
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
// but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
// the nested loops.
// This should have no impact on performance because the loops are unrolled anyway.
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified
// to one loop, but I think doing so would obfuscate the logic of the
// algorithm, thus I chose to keep the nested loops. This should have no
// impact on performance because the loops are unrolled anyway.
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] ;
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] ;
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
grad_reg[i][it] = (input_t)((acc_t)mask[i*element_count+it*WARP_SIZE] * (acc_t)grad[i*element_count+it*WARP_SIZE] * (acc_t)scale )*output[i*element_count+it*WARP_SIZE];
output_reg[i][it] = output[i*element_count+it*WARP_SIZE];
grad_reg[i][it] =
(input_t)((acc_t)mask[i * element_count + it * WARP_SIZE] *
(acc_t)grad[i * element_count + it * WARP_SIZE] *
(acc_t)scale) *
output[i * element_count + it * WARP_SIZE];
output_reg[i][it] = output[i * element_count + it * WARP_SIZE];
} else {
grad_reg[i][it] = acc_t(0);
output_reg[i][it] = acc_t(0);
......@@ -1476,55 +1674,68 @@ __global__ void masked_scale_softmax_warp_backward_masked_dgrad(output_t *gradIn
}
acc_t sum[WARP_BATCH];
#pragma unroll
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);
// store result
#pragma unroll
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
int total_ind = thread_offset + i*element_count + it*WARP_SIZE;
int pad_mask_ind = element_count*(total_ind/(heads * element_count * element_count)) + total_ind%element_count;
int total_ind = thread_offset + i * element_count + it * WARP_SIZE;
int pad_mask_ind =
element_count *
(total_ind / (heads * element_count * element_count)) +
total_ind % element_count;
uint8_t pad_mask_element = 1 - pad_mask[pad_mask_ind];
if (pad_mask_element == 0) gradInput[i*element_count+it*WARP_SIZE] = 0;
if (pad_mask_element == 0)
gradInput[i * element_count + it * WARP_SIZE] = 0;
else {
if (is_log_softmax) {
gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
gradInput[i * element_count + it * WARP_SIZE] =
(grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
} else {
gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);
gradInput[i * element_count + it * WARP_SIZE] =
(grad_reg[i][it] - output_reg[i][it] * sum[i]);
}
}
}
}
}
}
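// ---------------------------------------------------------------------------
// Illustrative sketch (not authoritative): a host-side, scalar reference of
// the row-wise math done by masked_scale_softmax_warp_backward_masked_dgrad
// above, for the non-log-softmax branch.  `y` is the saved softmax output,
// `dy` the incoming gradient, `dropout_mask`/`scale` the dropout mask and
// rescale factor, and `pad_mask == 1` marks padded key positions (the kernel
// zeroes those via `1 - pad_mask`).  The pad_mask indexing above appears to
// assume square attention (q_seq_len == k_seq_len == element_count): the
// sequence row is total_ind / (heads * element_count * element_count) and the
// key position is total_ind % element_count.
#include <cstdint>
#include <vector>

static std::vector<float>
masked_dgrad_row_reference(const std::vector<float> &dy,
                           const std::vector<float> &y,
                           const std::vector<uint8_t> &dropout_mask,
                           const std::vector<uint8_t> &pad_mask, float scale) {
  const size_t n = dy.size();
  std::vector<float> g(n), dx(n);
  float sum = 0.0f;
  for (size_t i = 0; i < n; ++i) {
    g[i] = dropout_mask[i] * scale * dy[i] * y[i]; // grad_reg in the kernel
    sum += g[i];                                   // warp_reduce_sum equivalent
  }
  for (size_t i = 0; i < n; ++i)
    dx[i] = pad_mask[i] ? 0.0f : (g[i] - y[i] * sum); // masked dgrad
  return dx;
}
// ---------------------------------------------------------------------------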
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_masked_scale_softmax_backward_masked_out(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int batch_count, int heads)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
template <typename input_t, typename output_t, typename acc_t,
bool is_log_softmax>
void dispatch_masked_scale_softmax_backward_masked_out(
output_t *grad_input, const input_t *grad, const input_t *output,
const uint8_t *mask, const uint8_t *pad_mask, acc_t scale,
int softmax_elements, int softmax_elements_stride, int batch_count,
int heads) {
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil_native(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_backward.
int warp_size =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
......@@ -1537,48 +1748,81 @@ void dispatch_masked_scale_softmax_backward_masked_out(output_t *grad_input, con
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 0, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
0, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 1: // 2
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 1, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
1, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 2: // 4
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 2, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
2, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 3: // 8
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 3, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
3, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 4: // 16
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 4, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
4, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 5: // 32
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 5, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
5, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 6: // 64
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 6, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
6, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 7: // 128
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 7, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
7, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 8: // 256
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 8, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
8, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 9: // 512
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 9, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
9, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 10: // 1024
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 10, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
10, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
default:
break;
......@@ -1586,18 +1830,25 @@ void dispatch_masked_scale_softmax_backward_masked_out(output_t *grad_input, con
}
}
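// Worked example of the dispatch arithmetic above (an illustrative helper,
// not part of the dispatcher; assumes a 32-wide hardware warp, i.e.
// C10_WARP_SIZE == 32).  For softmax_elements = 384:
//   log2_elements = 9, next_power_of_two = 512,
//   warp_size = min(512, 32) = 32, batches_per_warp = (512 <= 128 ? 2 : 1) = 1,
// so the switch selects the `case 9:  // 512` instantiation.
#include <algorithm>

struct WarpLaunchParams {
  int log2_elements;
  int warp_size;
  int batches_per_warp;
};

static WarpLaunchParams pick_warp_launch_params(int softmax_elements,
                                                int hw_warp_size = 32) {
  int log2_elements = 0;
  while ((1 << log2_elements) < softmax_elements) // log2_ceil_native
    ++log2_elements;
  const int next_power_of_two = 1 << log2_elements;
  return {log2_elements, std::min(next_power_of_two, hw_warp_size),
          (next_power_of_two <= 128) ? 2 : 1};
}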
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_masked_scale_softmax_backward_masked_out_stream(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int batch_count, int heads, cudaStream_t streamid)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
template <typename input_t, typename output_t, typename acc_t,
bool is_log_softmax>
void dispatch_masked_scale_softmax_backward_masked_out_stream(
output_t *grad_input, const input_t *grad, const input_t *output,
const uint8_t *mask, const uint8_t *pad_mask, acc_t scale,
int softmax_elements, int softmax_elements_stride, int batch_count,
int heads, cudaStream_t streamid) {
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil_native(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_backward.
int warp_size =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
......@@ -1608,48 +1859,81 @@ void dispatch_masked_scale_softmax_backward_masked_out_stream(output_t *grad_inp
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 0, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
0, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 1: // 2
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 1, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
1, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 2: // 4
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 2, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
2, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 3: // 8
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 3, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
3, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 4: // 16
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 4, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
4, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 5: // 32
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 5, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
5, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 6: // 64
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 6, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
6, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 7: // 128
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 7, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
7, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 8: // 256
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 8, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
8, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 9: // 512
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 9, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
9, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 10: // 1024
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 10, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
10, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
default:
break;
......@@ -1657,12 +1941,18 @@ void dispatch_masked_scale_softmax_backward_masked_out_stream(output_t *grad_inp
}
}
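// Hypothetical call site for the stream dispatcher above (a sketch only; the
// half instantiation mirrors the forward dispatch in this extension, but the
// function name, the argument values, and the 1/(1-p) dropout rescale are
// assumptions rather than code from this change).
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp16.h>

inline void bwd_masked_softmax_dropout_sketch(
    half *grad_input, const half *grad, const half *softmax_results,
    const uint8_t *dropout_mask, const uint8_t *pad_mask, float dropout_prob,
    int k_seq_len, int attn_batches, int q_seq_len, int heads) {
  dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float,
                                                           false>(
      grad_input, grad, softmax_results, dropout_mask, pad_mask,
      1.0f / (1.0f - dropout_prob), /*softmax_elements=*/k_seq_len,
      /*softmax_elements_stride=*/k_seq_len,
      /*batch_count=*/attn_batches * q_seq_len, heads,
      at::cuda::getCurrentCUDAStream().stream());
}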
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>
__global__ void masked_scale_softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output, const uint8_t *mask, acc_t scale, int batch_size, int stride, int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel.
template <typename input_t, typename output_t, typename acc_t,
int log2_elements, bool is_log_softmax>
__global__ void
masked_scale_softmax_warp_backward(output_t *gradInput, const input_t *grad,
const input_t *output, const uint8_t *mask,
acc_t scale, int batch_size, int stride,
int element_count) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_SIZE =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
......@@ -1674,7 +1964,8 @@ __global__ void masked_scale_softmax_warp_backward(output_t *gradInput, const in
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x % WARP_SIZE;
// the first element to process by the current thread
......@@ -1684,21 +1975,25 @@ __global__ void masked_scale_softmax_warp_backward(output_t *gradInput, const in
gradInput += thread_offset;
mask += thread_offset;
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
// but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
// the nested loops.
// This should have no impact on performance because the loops are unrolled anyway.
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified
// to one loop, but I think doing so would obfuscate the logic of the
// algorithm, thus I chose to keep the nested loops. This should have no
// impact on performance because the loops are unrolled anyway.
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] ;
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] ;
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
grad_reg[i][it] = (input_t)((acc_t)mask[i*element_count+it*WARP_SIZE] * (acc_t)grad[i*element_count+it*WARP_SIZE] * (acc_t)scale )*output[i*element_count+it*WARP_SIZE];
output_reg[i][it] = output[i*element_count+it*WARP_SIZE];
grad_reg[i][it] =
(input_t)((acc_t)mask[i * element_count + it * WARP_SIZE] *
(acc_t)grad[i * element_count + it * WARP_SIZE] *
(acc_t)scale) *
output[i * element_count + it * WARP_SIZE];
output_reg[i][it] = output[i * element_count + it * WARP_SIZE];
} else {
grad_reg[i][it] = acc_t(0);
output_reg[i][it] = acc_t(0);
......@@ -1707,43 +2002,45 @@ __global__ void masked_scale_softmax_warp_backward(output_t *gradInput, const in
}
acc_t sum[WARP_BATCH];
#pragma unroll
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);
// store result
#pragma unroll
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
if (is_log_softmax) {
gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
gradInput[i * element_count + it * WARP_SIZE] =
(grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
} else {
gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);
gradInput[i * element_count + it * WARP_SIZE] =
(grad_reg[i][it] - output_reg[i][it] * sum[i]);
}
}
}
}
}
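// Reference for the is_log_softmax branch of the kernel above (illustrative,
// scalar fp32, hypothetical helper).  With `logp` holding the saved output,
// the kernel forms the same g_i = mask_i * scale * dy_i * logp_i and row sum
// as the softmax branch, but writes
//   dx_i = g_i - exp(logp_i) * sum_j g_j
#include <cmath>
#include <cstdint>
#include <vector>

static std::vector<float>
masked_scale_log_softmax_bwd_row_reference(const std::vector<float> &dy,
                                           const std::vector<float> &logp,
                                           const std::vector<uint8_t> &mask,
                                           float scale) {
  const size_t n = dy.size();
  std::vector<float> g(n), dx(n);
  float sum = 0.0f;
  for (size_t i = 0; i < n; ++i) {
    g[i] = mask[i] * scale * dy[i] * logp[i];
    sum += g[i];
  }
  for (size_t i = 0; i < n; ++i)
    dx[i] = g[i] - std::exp(logp[i]) * sum;
  return dx;
}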
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE=32, int ELEMENTS_PER_LDG_STG, bool is_log_softmax>
__global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput, const input_t *grad, const input_t *softmax_input, const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size, int stride, int pad_batch_stride, int element_count)
{
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG,
bool is_log_softmax>
__global__ void masked_scale_softmax_warp_backward_recompute(
output_t *gradInput, const input_t *grad, const input_t *softmax_input,
const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size,
int stride, int pad_batch_stride, int element_count) {
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
......@@ -1752,12 +2049,13 @@ __global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x % WARP_SIZE;
//vectorize if a row length is multiple of 4
// vectorize if the row length is a multiple of 4
int flag_vec4 = (element_count & 3) == 0;
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] ;
input_t elements_input[WARP_BATCH][WARP_ITERATIONS] ;
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
// the first element to process by the current thread
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
......@@ -1767,53 +2065,61 @@ __global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput
gradInput += thread_offset;
mask += thread_offset;
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
// but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
// the nested loops.
// This should have no impact on performance because the loops are unrolled anyway.
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified
// to one loop, but I think doing so would obfuscate the logic of the
// algorithm, thus I chose to keep the nested loops. This should have no
// impact on performance because the loops are unrolled anyway.
// load data from global memory
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;
const input_t* curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride +
ELEMENTS_PER_LDG_STG * local_idx;
const input_t *curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
//masking_value is a large negative value
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
// masking_value is a large negative value
elements_input[i][it + element] = -10000;
grad_reg[i][it+element] = acc_t(0);
grad_reg[i][it + element] = acc_t(0);
}
if (element_index < batch_element_count) {
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], softmax_input + itr_idx);
apply_additive_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], curr_mask + itr_jmp); //(__half)-std::numeric_limits<float>::infinity()
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
softmax_input + itr_idx);
apply_additive_mask<input_t, ELEMENTS_PER_LDG_STG>(
&elements_input[i][it],
curr_mask +
itr_jmp); //(__half)-std::numeric_limits<float>::infinity()
uint8_t mask_temp[ELEMENTS_PER_LDG_STG];
input_t grad_temp[ELEMENTS_PER_LDG_STG];
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(&mask_temp[0], mask + itr_idx);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_temp[0], grad + itr_idx);
#pragma unroll
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
grad_reg[i][it+element] = ((acc_t)mask_temp[element] * (acc_t)grad_temp[element] * (acc_t)scale );
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(&mask_temp[0],
mask + itr_idx);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_temp[0],
grad + itr_idx);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
grad_reg[i][it + element] =
((acc_t)mask_temp[element] * (acc_t)grad_temp[element] *
(acc_t)scale);
}
}
}
}
// convert input_t to acc_t
// (no global memory access here; the data was already loaded above)
// TODO : remove this, input is already acc_t type in register
acc_t elements[WARP_BATCH][WARP_ITERATIONS] ;
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;++it) {
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = elements_input[i][it];
}
}
......@@ -1824,109 +2130,119 @@ __global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1;it < WARP_ITERATIONS;++it) {
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] =
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH] { 0.0f };
acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
//elements[i][it] = expf(elements[i][it] - max_value[i]);
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
// elements[i][it] = expf(elements[i][it] - max_value[i]);
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
// store result
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it ++) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it++) {
elements[i][it] = elements[i][it] / sum[i];
grad_reg[i][it] = grad_reg[i][it] * elements[i][it];
}
}
acc_t grad_sum[WARP_BATCH];
#pragma unroll
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
grad_sum[i] = grad_reg[i][0];
#pragma unroll
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
grad_sum[i] += grad_reg[i][it];
}
}
warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(grad_sum);
// store result
#pragma unroll
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t grad_input_reg[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element=0; element<ELEMENTS_PER_LDG_STG; element++) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; element++) {
if (is_log_softmax) {
grad_input_reg[element] = (grad_reg[i][it+element] - std::exp(elements[i][it+element]) * grad_sum[i]);
grad_input_reg[element] =
(grad_reg[i][it + element] -
std::exp(elements[i][it + element]) * grad_sum[i]);
} else {
grad_input_reg[element] = (grad_reg[i][it+element] - elements[i][it+element] * grad_sum[i]);
grad_input_reg[element] = (grad_reg[i][it + element] -
elements[i][it + element] * grad_sum[i]);
}
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, grad_input_reg);
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
gradInput + i * element_count + it * WARP_SIZE, grad_input_reg);
}
}
}
}
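// Host-side sketch of the recompute path above (hypothetical helper, scalar
// fp32, one row): the softmax output is rebuilt from the saved pre-softmax
// logits plus the additive pad mask (a large negative value such as -10000 at
// padded positions), and the scaled, dropout-masked dgrad is then applied to
// the recomputed probabilities, matching the non-log-softmax branch.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

static std::vector<float> recompute_bwd_row_reference(
    const std::vector<float> &dy, const std::vector<float> &logits,
    const std::vector<float> &additive_pad_mask, // 0 or a large negative value
    const std::vector<uint8_t> &dropout_mask, float scale) {
  const size_t n = dy.size();
  std::vector<float> y(n), g(n), dx(n);
  float max_v = -std::numeric_limits<float>::infinity();
  float denom = 0.0f, gsum = 0.0f;
  for (size_t i = 0; i < n; ++i)
    max_v = std::max(max_v, logits[i] + additive_pad_mask[i]);
  for (size_t i = 0; i < n; ++i) {
    y[i] = std::exp(logits[i] + additive_pad_mask[i] - max_v);
    denom += y[i];
  }
  for (size_t i = 0; i < n; ++i) {
    y[i] /= denom;                                 // recomputed softmax row
    g[i] = dropout_mask[i] * scale * dy[i] * y[i]; // grad_reg * elements
    gsum += g[i];
  }
  for (size_t i = 0; i < n; ++i)
    dx[i] = g[i] - y[i] * gsum; // softmax dgrad
  return dx;
}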
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
using masked_scale_softmax_warp_backward_recompute_func = void(*)(output_t *gradInput, const input_t *grad, const input_t *softmax_input, const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size, int stride, int pad_batch_stride, int element_count);
template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
bool masked_scale_softmax_warp_backward_recompute_kernel(int element_count, int log2_elements, int &warp_size, int &batches_per_warp, masked_scale_softmax_warp_backward_recompute_func<input_t, output_t, acc_t, is_log_softmax> &kernel) {
template <typename input_t, typename output_t, typename acc_t,
bool is_log_softmax>
using masked_scale_softmax_warp_backward_recompute_func = void (*)(
output_t *gradInput, const input_t *grad, const input_t *softmax_input,
const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size,
int stride, int pad_batch_stride, int element_count);
template <typename input_t, typename output_t, typename acc_t,
bool is_log_softmax>
bool masked_scale_softmax_warp_backward_recompute_kernel(
int element_count, int log2_elements, int &warp_size, int &batches_per_warp,
masked_scale_softmax_warp_backward_recompute_func<input_t, output_t, acc_t,
is_log_softmax> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
......@@ -1936,44 +2252,68 @@ bool masked_scale_softmax_warp_backward_recompute_kernel(int element_count, int
bool flag_vec4 = (element_count % 4 == 0);
switch (log2_elements) {
case 0: // 1
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,1,1, is_log_softmax>;
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 2, 1, 1, 1, is_log_softmax>;
break;
case 1: // 2
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,2,1, is_log_softmax>;
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 2, 1, 2, 1, is_log_softmax>;
break;
case 2: // 4
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,4,1, is_log_softmax>;
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 2, 1, 4, 1, is_log_softmax>;
break;
case 3: // 8
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,8,1, is_log_softmax>;
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 2, 1, 8, 1, is_log_softmax>;
break;
case 4: // 16
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,16,1, is_log_softmax>;
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 2, 1, 16, 1, is_log_softmax>;
break;
case 5: // 32
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,1,32,1, is_log_softmax>;
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 2, 1, 32, 1, is_log_softmax>;
break;
case 6: // 64
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,2,32,1, is_log_softmax>;
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 2, 2, 32, 1, is_log_softmax>;
break;
case 7: // 128
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2,4,32,1, is_log_softmax>;
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 2, 4, 32, 1, is_log_softmax>;
break;
case 8: // 256
if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,8,32,4, is_log_softmax>;
else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,8,32,1, is_log_softmax>;
if (flag_vec4)
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 1, 8, 32, 4, is_log_softmax>;
else
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 1, 8, 32, 1, is_log_softmax>;
break;
case 9: // 512
if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,16,32,4, is_log_softmax>;
else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,16,32,1, is_log_softmax>;
if (flag_vec4)
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 1, 16, 32, 4, is_log_softmax>;
else
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 1, 16, 32, 1, is_log_softmax>;
break;
case 10: // 1024
if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,32,32,4, is_log_softmax>;
else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,32,32,1, is_log_softmax>;
if (flag_vec4)
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 1, 32, 32, 4, is_log_softmax>;
else
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 1, 32, 32, 1, is_log_softmax>;
break;
case 11: // 2048
if (flag_vec4) kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,64,32,4, is_log_softmax>;
else kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1,64,32,1, is_log_softmax>;
if (flag_vec4)
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 1, 64, 32, 4, is_log_softmax>;
else
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 1, 64, 32, 1, is_log_softmax>;
break;
default:
return false;
......@@ -1981,20 +2321,31 @@ bool masked_scale_softmax_warp_backward_recompute_kernel(int element_count, int
return true;
}
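// The template parameters chosen in the switch above follow a fixed pattern;
// the helper below reproduces it for illustration only (an editorial sketch,
// the real code instantiates compile-time templates):
//   WARP_SIZE            = min(next_power_of_two, 32)
//   WARP_ITERATIONS      = next_power_of_two / WARP_SIZE
//   WARP_BATCH           = (next_power_of_two <= 128) ? 2 : 1
//   ELEMENTS_PER_LDG_STG = 4 only for rows >= 256 whose length is a multiple of 4
struct RecomputeKernelParams {
  int warp_batch;
  int warp_iterations;
  int warp_size;
  int elements_per_ldg_stg;
};

constexpr RecomputeKernelParams recompute_kernel_params(int log2_elements,
                                                        bool flag_vec4) {
  const int npot = 1 << log2_elements;
  const int warp_size = npot < 32 ? npot : 32;
  return {npot <= 128 ? 2 : 1, npot / warp_size, warp_size,
          (flag_vec4 && npot >= 256) ? 4 : 1};
}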
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
bool dispatch_masked_scale_softmax_backward_recompute(output_t *grad_input, const input_t *grad, const input_t *softmax_input, const input_t *pad_mask, const uint8_t *mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int pad_batch_stride, int batch_count, cudaStream_t streamid)
{
template <typename input_t, typename output_t, typename acc_t,
bool is_log_softmax>
bool dispatch_masked_scale_softmax_backward_recompute(
output_t *grad_input, const input_t *grad, const input_t *softmax_input,
const input_t *pad_mask, const uint8_t *mask, acc_t scale,
int softmax_elements, int softmax_elements_stride, int pad_batch_stride,
int batch_count, cudaStream_t streamid) {
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 2048) {
// compute function index. there's a function for each power of two size up to 1024.
// compute function index. there's a function for each power of two size up
// to 2048.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements) ++log2_elements;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
masked_scale_softmax_warp_backward_recompute_func<input_t, output_t, acc_t, is_log_softmax> kernel;
masked_scale_softmax_warp_backward_recompute_func<input_t, output_t, acc_t,
is_log_softmax>
kernel;
int warp_size, batches_per_warp;
if (!masked_scale_softmax_warp_backward_recompute_kernel<input_t, output_t, acc_t, is_log_softmax>(softmax_elements, log2_elements, warp_size, batches_per_warp, kernel)) {
if (!masked_scale_softmax_warp_backward_recompute_kernel<
input_t, output_t, acc_t, is_log_softmax>(
softmax_elements, log2_elements, warp_size, batches_per_warp,
kernel)) {
return false;
}
......@@ -2009,25 +2360,32 @@ bool dispatch_masked_scale_softmax_backward_recompute(output_t *grad_input, cons
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, streamid>>>(grad_input, grad, softmax_input, pad_mask, mask, scale, batch_count, softmax_elements_stride, pad_batch_stride, softmax_elements);
kernel<<<blocks, threads, 0, streamid>>>(
grad_input, grad, softmax_input, pad_mask, mask, scale, batch_count,
softmax_elements_stride, pad_batch_stride, softmax_elements);
return true;
}
return false;
}
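// Hypothetical caller for the recompute dispatcher above (a sketch; the
// fallback assertion and parameter values are assumptions, not part of this
// code).  The dispatcher returns false for unsupported row lengths, so a
// caller is expected to check the result.
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp16.h>

inline void bwd_recompute_sketch(half *grad_input, const half *grad,
                                 const half *softmax_input,
                                 const half *additive_pad_mask,
                                 const uint8_t *dropout_mask,
                                 float dropout_prob, int k_seq_len,
                                 int pad_batch_stride, int batch_count) {
  const bool launched =
      dispatch_masked_scale_softmax_backward_recompute<half, half, float,
                                                       false>(
          grad_input, grad, softmax_input, additive_pad_mask, dropout_mask,
          1.0f / (1.0f - dropout_prob), k_seq_len, k_seq_len, pad_batch_stride,
          batch_count, at::cuda::getCurrentCUDAStream().stream());
  TORCH_INTERNAL_ASSERT(launched,
                        "recompute path supports softmax_elements <= 2048");
}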
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_masked_scale_softmax_backward_stream(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int batch_count, cudaStream_t streamid)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
template <typename input_t, typename output_t, typename acc_t,
bool is_log_softmax>
void dispatch_masked_scale_softmax_backward_stream(
output_t *grad_input, const input_t *grad, const input_t *output,
const uint8_t *mask, acc_t scale, int softmax_elements,
int softmax_elements_stride, int batch_count, cudaStream_t streamid) {
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil_native(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_backward.
int warp_size =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
......@@ -2038,48 +2396,81 @@ void dispatch_masked_scale_softmax_backward_stream(output_t *grad_input, const i
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 0, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 0,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 1: // 2
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 1, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 1,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 2: // 4
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 2, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 2,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 3: // 8
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 3, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 3,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 4: // 16
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 4, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 4,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 5: // 32
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 5, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 5,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 6: // 64
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 6, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 6,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 7: // 128
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 7, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 7,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 8: // 256
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 8, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 8,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 9: // 512
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 9, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 9,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 10, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 10,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
default:
break;
......@@ -2087,14 +2478,20 @@ void dispatch_masked_scale_softmax_backward_stream(output_t *grad_input, const i
}
}
// elementwise multiplication called in at::softmax_backward_data is fused inside softmax dgrad kernel
// as a result of fusion, intermediate multiplication result is stored in fp32 in registers, instead of fp16
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>
__global__ void softmax_warp_backward_fused_native(output_t *gradInput, const input_t *grad, const input_t *output, int batch_size, int stride, int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel.
// The elementwise multiplication done by at::softmax_backward_data is fused
// into the softmax dgrad kernel; as a result of the fusion, the intermediate
// product is kept in fp32 registers instead of being rounded to fp16.
template <typename input_t, typename output_t, typename acc_t,
int log2_elements, bool is_log_softmax>
__global__ void
softmax_warp_backward_fused_native(output_t *gradInput, const input_t *grad,
const input_t *output, int batch_size,
int stride, int element_count) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_SIZE =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
......@@ -2106,7 +2503,8 @@ __global__ void softmax_warp_backward_fused_native(output_t *gradInput, const in
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x % WARP_SIZE;
// the first element to process by the current thread
......@@ -2115,21 +2513,22 @@ __global__ void softmax_warp_backward_fused_native(output_t *gradInput, const in
output += thread_offset;
gradInput += thread_offset;
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
// but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
// the nested loops.
// This should have no impact on performance because the loops are unrolled anyway.
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified
// to one loop, but I think doing so would obfuscate the logic of the
// algorithm, thus I chose to keep the nested loops. This should have no
// impact on performance because the loops are unrolled anyway.
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] ;
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] ;
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
grad_reg[i][it] = grad[i*element_count+it*WARP_SIZE]*output[i*element_count+it*WARP_SIZE];
output_reg[i][it] = output[i*element_count+it*WARP_SIZE];
grad_reg[i][it] = grad[i * element_count + it * WARP_SIZE] *
output[i * element_count + it * WARP_SIZE];
output_reg[i][it] = output[i * element_count + it * WARP_SIZE];
} else {
grad_reg[i][it] = acc_t(0);
output_reg[i][it] = acc_t(0);
......@@ -2138,50 +2537,57 @@ __global__ void softmax_warp_backward_fused_native(output_t *gradInput, const in
}
acc_t sum[WARP_BATCH];
#pragma unroll
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0]; //* output_reg[i][0];
#pragma unroll
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];// * output_reg[i][it];
sum[i] += grad_reg[i][it]; // * output_reg[i][it];
}
}
warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);
// store result
#pragma unroll
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
if (is_log_softmax) {
gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
gradInput[i * element_count + it * WARP_SIZE] =
(grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
} else {
gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);
gradInput[i * element_count + it * WARP_SIZE] =
(grad_reg[i][it] - output_reg[i][it] * sum[i]);
}
}
}
}
}
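// Illustrative device helpers (editorial sketch, hypothetical names) showing
// what the fusion described above changes: fused, the grad*output product is
// formed directly in fp32 (acc_t) registers; unfused, a separate elementwise
// kernel would round the product to fp16 before the row reduction.
#include <cuda_fp16.h>

__device__ inline float fused_product_sketch(__half grad, __half output) {
  // product kept in fp32, as grad_reg is above
  return __half2float(grad) * __half2float(output);
}

__device__ inline float unfused_product_sketch(__half grad, __half output) {
  // product rounded to fp16 first, then widened for the reduction
  return __half2float(__float2half(__half2float(grad) * __half2float(output)));
}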
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_softmax_backward_fused_native(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
template <typename input_t, typename output_t, typename acc_t,
bool is_log_softmax>
void dispatch_softmax_backward_fused_native(
output_t *grad_input, const input_t *grad, const input_t *output,
int softmax_elements, int softmax_elements_stride, int batch_count) {
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil_native(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_backward.
int warp_size =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
......@@ -2194,48 +2600,81 @@ void dispatch_softmax_backward_fused_native(output_t *grad_input, const input_t
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 0, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 0,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 1: // 2
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 1, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 1,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 2: // 4
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 2, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 2,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 3: // 8
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 3, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 3,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 4: // 16
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 4, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 4,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 5: // 32
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 5, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 5,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 6: // 64
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 6, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 6,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 7: // 128
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 7, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 7,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 8: // 256
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 8, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 8,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 9: // 512
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 9, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 9,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 10: // 1024
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 10, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 10,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
default:
break;
......@@ -2247,9 +2686,11 @@ void dispatch_softmax_backward_fused_native(output_t *grad_input, const input_t
// Warp softmax backward
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE=32, int ELEMENTS_PER_LDG_STG=1>
__global__ void softmax_warp_backward(__half *gradInput, const __half *grad, const __half *output, int batch_size, int stride, int element_count)
{
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void softmax_warp_backward(__half *gradInput, const __half *grad,
const __half *output, int batch_size,
int stride, int element_count) {
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
......@@ -2258,7 +2699,8 @@ __global__ void softmax_warp_backward(__half *gradInput, const __half *grad, con
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
......@@ -2270,84 +2712,88 @@ __global__ void softmax_warp_backward(__half *gradInput, const __half *grad, con
// load data from global memory
input_t grad_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};
input_t output_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE);
copy_vector<input_t,ELEMENTS_PER_LDG_STG>(&output_reg_input[i][it], output + i * element_count + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
&grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&output_reg_input[i][it],
output + i * element_count +
it * WARP_SIZE);
}
}
}
// convert half to floating point
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
grad_reg[i][it] = grad_reg_input[i][it];
output_reg[i][it] = output_reg_input[i][it];
}
}
// compute thread local sum
acc_t sum[WARP_BATCH] = {0};
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;++it) {
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += grad_reg[i][it] * output_reg[i][it];
}
}
// reduction sum
constexpr uint32_t FULL_MASK = 0xffffffff;
#pragma unroll
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
// store result
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t out[ELEMENTS_PER_LDG_STG];
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
out[element] = (output_reg[i][it+element] * (grad_reg[i][it+element] - sum[i]));
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = (output_reg[i][it + element] *
(grad_reg[i][it + element] - sum[i]));
}
// store them in global memory
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
gradInput + i * element_count + it * WARP_SIZE, out);
}
}
}
}
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
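// For example (reading off the dispatch table below), element_count = 128 gives
// log2_elements = 7 and the instantiation <.., 2, 4, 32, 1>: each warp handles
// WARP_BATCH = 2 rows, every lane of the 32-wide warp covers WARP_ITERATIONS = 4
// elements per row, and loads/stores are scalar (ELEMENTS_PER_LDG_STG = 1).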
template <typename input_t, typename output_t>
using softmax_backward_func = void(*)(output_t *gradInput, const input_t *grad, const input_t *output, int batch_size, int stride, int element_count);
using softmax_backward_func = void (*)(output_t *gradInput, const input_t *grad,
const input_t *output, int batch_size,
int stride, int element_count);
template <typename input_t, typename output_t, typename acc_t>
bool warp_softmax_backward_kernel(int log2_elements, int &warp_size, int &batches_per_warp, softmax_backward_func<input_t, output_t> &kernel) {
bool warp_softmax_backward_kernel(
int log2_elements, int &warp_size, int &batches_per_warp,
softmax_backward_func<input_t, output_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
......@@ -2357,37 +2803,37 @@ bool warp_softmax_backward_kernel(int log2_elements, int &warp_size, int &batche
switch (log2_elements) {
case 0: // 1
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,1,1,1>;
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 1, 1>;
break;
case 1: // 2
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,1,2,1>;
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 2, 1>;
break;
case 2: // 4
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,1,4,1>;
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 4, 1>;
break;
case 3: // 8
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,1,8,1>;
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 8, 1>;
break;
case 4: // 16
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,1,16,1>;
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 16, 1>;
break;
case 5: // 32
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,1,32,1>;
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 32, 1>;
break;
case 6: // 64
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,2,32,1>;
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 2, 32, 1>;
break;
case 7: // 128
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,4,32,1>;
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 4, 32, 1>;
break;
case 8: // 256
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1,8,32,1>;
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1, 8, 32, 1>;
break;
case 9: // 512
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1,16,32,1>;
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1, 16, 32, 1>;
break;
case 10: // 1024
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1,32,32,1>;
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1, 32, 32, 1>;
break;
default:
return false;
......@@ -2395,19 +2841,23 @@ bool warp_softmax_backward_kernel(int log2_elements, int &warp_size, int &batche
return true;
}
template<typename input_t, typename output_t, typename acc_t>
bool dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count)
{
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_softmax_backward(output_t *grad_input, const input_t *grad,
const input_t *output, int softmax_elements,
int softmax_elements_stride, int batch_count) {
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 1024) {
// compute function index. there's a function for each power of two size up to 1024.
// compute function index. there's a function for each power of two size up
// to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements) ++log2_elements;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
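// e.g. softmax_elements = 300 -> log2_elements = 9: the 512-wide kernel is
// selected and the per-element bounds checks skip the 212 padding positions.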
softmax_backward_func<input_t, output_t> kernel;
int warp_size, batches_per_warp;
if (!warp_softmax_backward_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {
if (!warp_softmax_backward_kernel<input_t, output_t, acc_t>(
log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
......@@ -2423,24 +2873,32 @@ bool dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
return true;
}
return false;
}
template<typename input_t, typename output_t, typename acc_t>
bool dispatch_softmax_backward_stream(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count, cudaStream_t streamid)
{
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_softmax_backward_stream(output_t *grad_input, const input_t *grad,
const input_t *output,
int softmax_elements,
int softmax_elements_stride,
int batch_count, cudaStream_t streamid) {
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 1024) {
// compute function index. there's a function for each power of two size up to 1024.
// compute function index. there's a function for each power of two size up
// to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements) ++log2_elements;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
softmax_backward_func<input_t, output_t> kernel;
int warp_size, batches_per_warp;
if (!warp_softmax_backward_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {
if (!warp_softmax_backward_kernel<input_t, output_t, acc_t>(
log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
// use 128 threads per block to maximize gpu utilization
......@@ -2452,15 +2910,21 @@ bool dispatch_softmax_backward_stream(output_t *grad_input, const input_t *grad,
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
kernel<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
return true;
}
return false;
}
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE=32, int ELEMENTS_PER_LDG_STG=1>
__global__ void masked_softmax_warp_backward(__half *gradInput, const __half *grad, const __half *output, const uint8_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride)
{
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void
masked_softmax_warp_backward(__half *gradInput, const __half *grad,
const __half *output, const uint8_t *pad_mask,
int batch_size, int stride, int element_count,
int pad_batch_stride) {
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
......@@ -2469,7 +2933,8 @@ __global__ void masked_softmax_warp_backward(__half *gradInput, const __half *gr
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
......@@ -2481,91 +2946,97 @@ __global__ void masked_softmax_warp_backward(__half *gradInput, const __half *gr
// load data from global memory
input_t grad_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};
input_t output_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE);
copy_vector<input_t,ELEMENTS_PER_LDG_STG>(&output_reg_input[i][it], output + i * element_count + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
&grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&output_reg_input[i][it],
output + i * element_count +
it * WARP_SIZE);
}
}
}
// convert half to floating point
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
grad_reg[i][it] = grad_reg_input[i][it];
output_reg[i][it] = output_reg_input[i][it];
}
}
// compute thread local sum
acc_t sum[WARP_BATCH] = {0};
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;++it) {
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += grad_reg[i][it] * output_reg[i][it];
}
}
// reduction sum
constexpr uint32_t FULL_MASK = 0xffffffff;
#pragma unroll
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
// store result
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;
const uint8_t* curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride +
ELEMENTS_PER_LDG_STG * local_idx;
const uint8_t *curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t out[ELEMENTS_PER_LDG_STG];
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
out[element] = (output_reg[i][it+element] * (grad_reg[i][it+element] - sum[i]));
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = (output_reg[i][it + element] *
(grad_reg[i][it + element] - sum[i]));
}
// store them in global memory
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
// It is kind of unfortunate this has to be here to zero something out that is close to
// zero in the first place
apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&out[0], 0.0, curr_mask + itr_jmp);
// It is kind of unfortunate this has to be here to zero something out
// that is close to zero in the first place
apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&out[0], 0.0,
curr_mask + itr_jmp);
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + itr_idx, out);
}
}
}
}
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
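// Compared to the unmasked kernel above, each row additionally looks up its
// padding mask via pad_batch_stride (the number of attention rows that share one
// mask row), and masked positions of the gradient are zeroed before the store.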
template <typename input_t, typename output_t>
using masked_softmax_backward_func = void(*)(output_t *gradInput, const input_t *grad, const input_t *output, const uint8_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride);
using masked_softmax_backward_func =
void (*)(output_t *gradInput, const input_t *grad, const input_t *output,
const uint8_t *pad_mask, int batch_size, int stride,
int element_count, int pad_batch_stride);
template <typename input_t, typename output_t, typename acc_t>
bool warp_masked_softmax_backward_kernel(int log2_elements, int &warp_size, int &batches_per_warp, masked_softmax_backward_func<input_t, output_t> &kernel) {
bool warp_masked_softmax_backward_kernel(
int log2_elements, int &warp_size, int &batches_per_warp,
masked_softmax_backward_func<input_t, output_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
......@@ -2575,37 +3046,48 @@ bool warp_masked_softmax_backward_kernel(int log2_elements, int &warp_size, int
switch (log2_elements) {
case 0: // 1
kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,1,1,1>;
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 1, 1>;
break;
case 1: // 2
kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,1,2,1>;
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 2, 1>;
break;
case 2: // 4
kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,1,4,1>;
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 4, 1>;
break;
case 3: // 8
kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,1,8,1>;
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 8, 1>;
break;
case 4: // 16
kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,1,16,1>;
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 16, 1>;
break;
case 5: // 32
kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,1,32,1>;
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 32, 1>;
break;
case 6: // 64
kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,2,32,1>;
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 2, 32, 1>;
break;
case 7: // 128
kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,4,32,1>;
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 4, 32, 1>;
break;
case 8: // 256
kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 1,8,32,1>;
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 1, 8, 32, 1>;
break;
case 9: // 512
kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 1,16,32,1>;
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 1, 16, 32, 1>;
break;
case 10: // 1024
kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 1,32,32,1>;
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 1, 32, 32, 1>;
break;
default:
return false;
......@@ -2613,19 +3095,26 @@ bool warp_masked_softmax_backward_kernel(int log2_elements, int &warp_size, int
return true;
}
template<typename input_t, typename output_t, typename acc_t>
bool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *pad_mask, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride)
{
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad,
const input_t *output,
const uint8_t *pad_mask,
int softmax_elements,
int softmax_elements_stride,
int batch_count, int pad_batch_stride) {
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 1024) {
// compute function index. there's a function for each power of two size up to 1024.
// compute function index. there's a function for each power of two size up
// to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements) ++log2_elements;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
masked_softmax_backward_func<input_t, output_t> kernel;
int warp_size, batches_per_warp;
if (!warp_masked_softmax_backward_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {
if (!warp_masked_softmax_backward_kernel<input_t, output_t, acc_t>(
log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
......@@ -2641,8 +3130,11 @@ bool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad,
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, pad_mask, batch_count,
softmax_elements_stride, softmax_elements, pad_batch_stride);
return true;
}
return false;
}
#include <vector>
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
//#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
//#include "cutlass/cutlass.h"
//#include "cutlass/gemm/gemm.h"
//#include "cutlass/gemm/wmma_gemm_traits.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
rocblas_datatype a_type = rocblas_datatype_f16_r;
rocblas_datatype b_type = rocblas_datatype_f16_r;
......@@ -25,16 +28,19 @@ rocblas_int flags = 0;
cublasOperation_t convertTransToCublasOperation(char trans) {
if (trans == 't') return CUBLAS_OP_T;
else if (trans == 'n') return CUBLAS_OP_N;
else if (trans == 'c') return CUBLAS_OP_C;
if (trans == 't')
return CUBLAS_OP_T;
else if (trans == 'n')
return CUBLAS_OP_N;
else if (trans == 'c')
return CUBLAS_OP_C;
else {
AT_ERROR("trans must be one of: t, n, c");
return CUBLAS_OP_T;
}
}
void RocblasStridedBatchedGemm(THCState *state, char transa, char transb, long m, long n, long k,
void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo) {
cublasOperation_t opa = convertTransToCublasOperation(transa);
......@@ -55,151 +61,73 @@ void RocblasStridedBatchedGemm(THCState *state, char transa, char transb, long m
(int)batchCount, compute_type, algo, solution_index, flags));
}
void gemm_switch_fp32accum(THCState *state, char transa, char transb, long m, long n, long k,
void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount) {
auto stream = c10::cuda::getCurrentCUDAStream();
if ( (transa == 't') && (transb == 'n') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
else { RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
} else if ( (transa == 'n') && (transb == 'n') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
else { RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
} else if ( (transa == 'n') && (transb == 't') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
else { RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
} else {
AT_ASSERTM(false, "TransA and TransB are invalid");
}
}
void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t *lda, int64_t *ldb, int64_t *ldc)
{
void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k,
int64_t *lda, int64_t *ldb, int64_t *ldc) {
int transa_ = ((transa == 't') || (transa == 'T'));
int transb_ = ((transb == 't') || (transb == 'T'));
// Note: leading dimensions generally are checked that they are > 0 and at least as big the result
// requires (even if the value won't be used).
if(n <= 1)
// Note: leading dimensions are generally checked to be > 0 and at least as
// big as the result requires (even if the value won't be used).
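// For example, a caller forming an m x 1 result may pass ldc = 0 (derived from
// a size-1 dimension's stride); clamping it to max(m, 1) below satisfies the
// BLAS parameter checks even though that leading dimension is never dereferenced.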
if (n <= 1)
*ldc = std::max<int64_t>(m, 1);
if(transa_)
{
if(m <= 1)
if (transa_) {
if (m <= 1)
*lda = std::max<int64_t>(k, 1);
}
else
{
if(k <= 1)
} else {
if (k <= 1)
*lda = std::max<int64_t>(m, 1);
}
if(transb_)
{
if(k <= 1)
if (transb_) {
if (k <= 1)
*ldb = std::max<int64_t>(n, 1);
}
else
{
if(n <= 1)
} else {
if (n <= 1)
*ldb = std::max<int64_t>(k, 1);
}
}
void HgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount)
{
if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
void HgemmStridedBatched(char transa, char transb, long m,
long n, long k, float alpha, const half *a, long lda,
long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC,
half *d, long ldd, long strideD, long batchCount) {
if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) ||
(ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX))
{
AT_ERROR("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount"
"with the bound [val] <= %d", INT_MAX);
AT_ERROR("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, "
"batchCount"
"with the bound [val] <= %d",
INT_MAX);
}
adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
//gemm_switch(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
gemm_switch_fp32accum(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount);
// gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA,
// b, ldb, strideB, beta, c, ldc, strideC, batchCount);
gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA,
b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount);
}
/******
at::Tensor strided_batched_gemm_cuda(
float beta,
at::Tensor in_result,
float alpha,
at::Tensor batch1,
at::Tensor batch2) {
bool transpose_result;
char transpose_batch1, transpose_batch2;
int64_t lda, ldb, ldc;
at::Tensor result, input1, input2;
if (in_result.stride(1) == 1)
{
transpose_result = false;
result = in_result;
ldc = result.stride(2);
}
else if (in_result.stride(2) == 1)
{
transpose_result = true;
at::Tensor swap = batch2;
batch2 = batch1;
batch1 = swap;
result = in_result;
ldc = result.stride(1);
} else {
AT_ASSERTM(false, "result should be contiguous");
}
if (batch1.stride(transpose_result ? 2 : 1) == 1 &&
batch1.stride(transpose_result ? 1 : 2) != 0) {
transpose_batch1 = 'n';
input1 = batch1;
lda = input1.stride(transpose_result ? 1 : 2);
} else if (batch1.stride(transpose_result ? 1 : 2) == 1 &&
batch1.stride(transpose_result ? 2 : 1) != 0) {
transpose_batch1 = 't';
input1 = batch1;
lda = input1.stride(transpose_result ? 2 : 1);
} else {
AT_ASSERTM(false, "input1 should be contiguous");
}
if (batch2.stride(transpose_result ? 2 : 1) == 1 &&
batch2.stride(transpose_result ? 1 : 2) != 0) {
transpose_batch2 = 'n';
input2 = batch2;
ldb = input2.stride(transpose_result ? 1 : 2);
} else if (batch2.stride(transpose_result ? 1 : 2) == 1 &&
batch2.stride(transpose_result ? 2 : 1) != 0) {
transpose_batch2 = 't';
input2 = batch2;
ldb = input2.stride(transpose_result ? 2 : 1);
} else {
AT_ASSERTM(false, "input2 should be contiguous");
}
int64_t num_batches = result.size(0);
HgemmStridedBatched(
state,
transpose_batch1,
transpose_batch2,
result.size(transpose_result ? 2 : 1),
result.size(transpose_result ? 1 : 2),
input1.size(transpose_result ? 1 : 2),
alpha,
static_cast<const half*>(input1.data_ptr()), lda, input1.stride(0),
static_cast<const half*>(input2.data_ptr()), ldb, input2.stride(0),
beta,
static_cast<half*>(result.data_ptr()), ldc, result.stride(0),
num_batches);
return in_result;
}
***/
import torch
from torch.nn import init
from apex._autocast_utils import _cast_if_autocast_enabled
import fast_layer_norm
class FastLayerNormFN(torch.autograd.Function):
@staticmethod
def forward(ctx, x, gamma, beta, epsilon):
......@@ -17,20 +19,27 @@ class FastLayerNormFN(torch.autograd.Function):
@staticmethod
def backward(ctx, dy):
#assert dy.is_contiguous()
# assert dy.is_contiguous()
dy = dy.contiguous() # this happens!
x, gamma, mu, rsigma = ctx.saved_tensors
hidden_size = gamma.numel()
xmat = x.view((-1, hidden_size))
dymat = dy.view(xmat.shape)
dxmat, dgamma, dbeta = fast_layer_norm.ln_bwd(dymat, xmat, mu, rsigma, gamma)
dxmat, dgamma, dbeta, _, _ = fast_layer_norm.ln_bwd(dymat, xmat, mu, rsigma, gamma)
dx = dxmat.view(x.shape)
return dx, dgamma, dbeta, None
def _fast_layer_norm(x, weight, bias, epsilon):
args = _cast_if_autocast_enabled(x, weight, bias, epsilon)
with torch.cuda.amp.autocast(enabled=False):
return FastLayerNormFN.apply(*args)
class FastLayerNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5):
super(FastLayerNorm, self).__init__()
super().__init__()
self.epsilon = eps
self.weight = torch.nn.Parameter(torch.Tensor(hidden_size))
self.bias = torch.nn.Parameter(torch.Tensor(hidden_size))
......@@ -41,4 +50,4 @@ class FastLayerNorm(torch.nn.Module):
init.zeros_(self.bias)
def forward(self, x):
return FastLayerNormFN.apply(x, self.weight, self.bias, self.epsilon)
return _fast_layer_norm(x, self.weight, self.bias, self.epsilon)
import torch
import unittest
import numpy as np
import sys
import os
import torch.nn.functional as F
from apex.contrib.layer_norm import FastLayerNorm
import numpy as np
import torch
import fast_layer_norm as fln
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
class GPUTimer:
......@@ -14,146 +14,262 @@ class GPUTimer:
self.start_ = torch.cuda.Event(enable_timing=True)
self.stop_ = torch.cuda.Event(enable_timing=True)
self.stream_ = stream
def start(self):
self.stream_.record_event(self.start_)
def stop(self):
self.stream_.record_event(self.stop_)
def sync(self):
self.stream_.synchronize()
def millis(self):
return self.start_.elapsed_time(self.stop_)
def size_in_bytes(t):
return torch.numel(t) * t.element_size()
def abs_err(x, y):
xf = x.float()
yf = y.float()
return ((xf-yf).abs().sum() / yf.abs().sum()).item()
class TestFastLayerNorm(unittest.TestCase):
def setUp(self, seed=1234):
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def test_ln_fp32(self):
self.run_test_layer_norm(torch.float32, atol=1e-5)
def test_ln_fp16(self):
self.run_test_layer_norm(torch.float16, atol=1e-2, rtol=1e-3)
def run_test_layer_norm(self, dtype, atol, rtol=1e-5):
device = torch.device('cuda')
s = 512
b = 32
hidden_size = 1024
epsilon = 1e-5
x = torch.randn((s,b,hidden_size), dtype=dtype, device=device)
beta = torch.randn(hidden_size, dtype=dtype, device=device)
gamma = torch.randn(hidden_size, dtype=dtype, device=device)
x.requires_grad = True
beta.requires_grad = True
gamma.requires_grad = True
x2 = x.clone().detach()
beta2 = beta.clone().detach()
gamma2 = gamma.clone().detach()
x2.requires_grad = True
beta2.requires_grad = True
gamma2.requires_grad = True
def metrics(y_ref, y, epsilon=1e-6):
y_ref = y_ref.float()
y = y.float()
relerr, mse = (
(y_ref - y).abs().sum() / (y_ref.abs().sum() + epsilon),
(y_ref - y).square().mean(),
)
return relerr.item(), mse.item()
dummy_label = torch.randn_like(x)
y = F.layer_norm(x, [hidden_size], gamma, beta, epsilon)
device = torch.device("cuda")
fp32 = torch.float32
fp16 = torch.float16
bf16 = torch.bfloat16
diff = y-dummy_label
l = (diff * diff).sum() / b
l.backward()
fln = FastLayerNorm(hidden_size).cuda()
fln.load_state_dict({'bias': beta2, 'weight':gamma2})
if dtype == torch.float16:
fln = fln.half()
def backward_(dz, x, mu, rs, gamma):
y2 = fln(x2)
diff2 = (y2 - dummy_label)
l2 = (diff2 * diff2).sum() / b
wtype = gamma.dtype
itype = x.dtype
otype = dz.dtype
ctype = mu.dtype
mu = mu.unsqueeze(1)
rs = rs.unsqueeze(1)
l2.backward()
hidden_size = gamma.numel()
y = rs * (x.to(ctype) - mu)
dbeta = dz.view(-1, hidden_size).sum(0, dtype=ctype)
dgamma = (dz * y).view(-1, hidden_size).sum(0, dtype=ctype)
dy = dz.view(-1, hidden_size).to(ctype) * gamma.unsqueeze(0).to(ctype)
mdy = dy.mean(1, keepdim=True, dtype=ctype)
self.assertTrue(torch.allclose(y2, y, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(x2.grad, x.grad, atol=atol,rtol=rtol))
self.assertTrue(torch.allclose(fln.bias.grad, beta.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(fln.weight.grad, gamma.grad, atol=atol, rtol=rtol))
mdyy = (dy * y).mean(1, keepdim=True, dtype=ctype)
dx = rs * (dy - mdyy * y - mdy)
return dx.to(itype), dgamma.to(wtype), dbeta.to(wtype)
def test_performance(self):
print()
runs = 1000
device = torch.device('cuda')
dtype =torch.float16
s = 512
b = 32
hidden_size = 1024
def benchmark_(S, B, hidden_size, itype, wtype, runs=100):
epsilon = 1e-5
x = torch.randn((s*b,hidden_size), dtype=dtype, device=device)
beta = torch.randn(hidden_size, dtype=dtype, device=device)
gamma = torch.randn(hidden_size, dtype=dtype, device=device)
dy = torch.randn_like(x)
x = torch.randn((S * B, hidden_size), dtype=itype, device=device)
beta = torch.randn(hidden_size, dtype=wtype, device=device)
gamma = torch.randn(hidden_size, dtype=wtype, device=device)
dz = torch.randn(x.shape, dtype=wtype, device=device)
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
timer = GPUTimer(stream)
#warmup
# warmup
for r in range(runs):
y, mu, rsigma = fln.ln_fwd(x, gamma, beta, 1e-5)
z, mu, rsigma = fln.ln_fwd(x, gamma, beta, epsilon)
timer.start()
for r in range(runs):
y, mu, rsigma = fln.ln_fwd(x, gamma, beta, 1e-5)
z, mu, rsigma = fln.ln_fwd(x, gamma, beta, epsilon)
timer.stop()
timer.sync()
total_bytes_fwd = (size_in_bytes(x)
+ size_in_bytes(y)
+ size_in_bytes(gamma)
+ size_in_bytes(beta)
+ size_in_bytes(mu)
+ size_in_bytes(rsigma)
)
total_bytes_fwd = sum([size_in_bytes(t) for t in [x, z, gamma, beta, mu, rsigma]])
ms_fwd = timer.millis() / runs
print('[FWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec'.format(ms_fwd, total_bytes_fwd * 1e-6 / ms_fwd ))
print(
"[FWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec".format(
ms_fwd, total_bytes_fwd * 1e-6 / ms_fwd
)
)
timer.start()
for r in range(runs):
dx, dgamma, dbeta = fln.ln_bwd(dy, x, mu, rsigma, gamma)
dx, dgamma, dbeta, dbp, dgp = fln.ln_bwd(dz, x, mu, rsigma, gamma)
timer.stop()
timer.sync()
total_bytes_bwd = (size_in_bytes(x)
+ size_in_bytes(dx)
+ size_in_bytes(dy)
+ size_in_bytes(gamma)
+ size_in_bytes(dgamma)
+ size_in_bytes(dbeta)
+ size_in_bytes(mu)
+ size_in_bytes(rsigma)
total_bytes_bwd = sum(
[
size_in_bytes(t)
for t in [dz, x, mu, rsigma, gamma, dx, dgamma, dbeta, dbp, dbp, dgp, dgp]
]
)
ms_bwd = timer.millis() / runs
print('[BWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec'.format(ms_bwd, total_bytes_bwd * 1e-6 / ms_bwd ))
if __name__ == '__main__':
print(
"[BWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec".format(
ms_bwd, total_bytes_bwd * 1e-6 / ms_bwd
)
)
def test_(S, B, hidden_size, itype, wtype, ctype=fp32):
seed = 1243
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
otype = wtype
print("========================================================")
print(f"S={S} B={B} Hidden={hidden_size} {itype} {wtype}")
print("--------------------------------------------------------")
x = torch.randn(S * B, hidden_size, dtype=itype, device=device)
gamma = torch.randn(hidden_size, dtype=wtype, device=device) * 0.2
beta = torch.randn(hidden_size, dtype=wtype, device=device) * 0.2
epsilon = 1e-5
x.requires_grad = True
gamma.requires_grad = True
beta.requires_grad = True
mu_ref = x.mean(1, dtype=ctype, keepdim=True)
v = torch.square(x - mu_ref).mean(1, dtype=ctype, keepdim=True)
rs_ref = torch.rsqrt(v + epsilon)
y_ref = rs_ref * (x.to(ctype) - mu_ref)
z_ref = (gamma.unsqueeze(0) * (y_ref).to(otype) + beta.unsqueeze(0)).to(otype)
mu_ref = mu_ref.flatten()
rs_ref = rs_ref.flatten()
dz = torch.randn_like(z_ref)
# z_ref.backward(dz)
# dx_ref = x.grad
# dgamma_ref = gamma.grad
# dbeta_ref = beta.grad
dx_ref, dg_ref, db_ref = backward_(dz, x, mu_ref, rs_ref, gamma)
z, mu, rs = fln.ln_fwd(x, gamma, beta, epsilon)
dx, dg, db, dg_part, db_part = fln.ln_bwd(dz, x, mu, rs, gamma)
re_z, mse_z = metrics(z_ref, z)
re_mu, mse_mu = metrics(mu_ref, mu)
re_rs, mse_rs = metrics(rs_ref, rs)
re_dx, mse_dx = metrics(dx_ref, dx)
re_dg, mse_dg = metrics(dg_ref, dg)
re_db, mse_db = metrics(db_ref, db)
print(f" z: relerr={re_z :.4e} mse={mse_z :.4e}")
print(f"mu: relerr={re_mu:.4e} mse={mse_mu:.4e}")
print(f"rs: relerr={re_mu:.4e} mse={mse_mu:.4e}")
print(f"dx: relerr={re_dx:.4e} mse={mse_dx:.4e}")
print(f"dg: relerr={re_dg:.4e} mse={mse_dg:.4e}")
print(f"db: relerr={re_db:.4e} mse={mse_db:.4e}")
def check_err(x, relerr):
tol = 1e-3 if x.dtype == torch.float16 else 5e-6
return relerr < tol
return [
check_err(x, re)
for x, re in zip([z, mu, rs, dx, dg, db], [re_z, re_mu, re_rs, re_dx, re_dg, re_db])
]
class TestFastLayerNorm(unittest.TestCase):
def assertAll(self, l):
if not all(l):
print(l)
for x in l:
self.assertTrue(x)
def test_all_configs(self):
hidden_sizes = [
768,
1024,
1536,
2048,
2304,
3072,
3840,
4096,
5120,
6144,
8192,
10240,
12288,
12800,
15360,
16384,
18432,
20480,
24576,
25600,
30720,
32768,
40960,
49152,
65536,
]
for h in hidden_sizes:
with self.subTest(f"hidden_size={h}"):
self.assertAll(test_(256, 2, h, fp32, fp32))
self.assertAll(test_(256, 2, h, fp16, fp16))
self.assertAll(test_(256, 2, h, fp32, fp16))
self.assertAll(test_(256, 2, h, bf16, bf16))
self.assertAll(test_(256, 2, h, fp32, bf16))
def test_run_benchmark(self):
for (S, B, hidden_size, runs) in (
(512, 32, 768, 1000),
(512, 32, 1024, 1000),
(512, 8, 4096, 1000),
(512, 8, 5120, 1000),
(512, 8, 6144, 1000),
(256, 2, 20480, 500),
(256, 2, 25600, 500),
(256, 2, 40960, 250),
(256, 2, 65536, 250),
):
with self.subTest(f"(S, B, hidden_size)=({S}, {B}, {hidden_size})"):
benchmark_(S, B, hidden_size, fp16, fp16, runs)
def test_compat_with_autocast(self):
autocast_dtypes = (
(torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
)
input_shape = (512, 32, 768)
layer_norm = FastLayerNorm(input_shape[-1]).cuda()
input = torch.randn(input_shape).cuda()
for dtype in autocast_dtypes:
layer_norm.zero_grad(set_to_none=True)
with self.subTest(f"autocast_dtype={dtype}"):
with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
out = layer_norm(input)
self.assertEqual(dtype, out.dtype)
grad = torch.randn_like(out)
out.backward(grad)
self.assertEqual(torch.float32, layer_norm.weight.grad.dtype)
if __name__ == "__main__":
unittest.main()
......@@ -3,3 +3,4 @@ from .fused_adam import FusedAdam
from .fused_novograd import FusedNovoGrad
from .fused_lamb import FusedLAMB
from .fused_adagrad import FusedAdagrad
from .fused_mixed_precision_lamb import FusedMixedPrecisionLamb
import torch
from copy import deepcopy
from itertools import chain
from collections import defaultdict, abc as container_abcs
from apex.multi_tensor_apply import multi_tensor_applier
class FusedMixedPrecisionLamb(torch.optim.Optimizer):
def __init__(self, params, lr=1e-3, step=0, bias_correction=True,
betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
amsgrad=False, adam_w_mode=True,
grad_averaging=True, max_grad_norm=1.0, use_nvlamb=False,
reduced_precision_dtype=None):
if amsgrad:
raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
# The learning rate (lr) and optimizer step (step) should be located on device
# in order to facilitate device-sync-free execution
defaults = dict(lr=torch.tensor(lr, dtype=torch.float32),
step=torch.tensor([step], dtype=torch.int),
bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
grad_averaging=grad_averaging,
max_grad_norm=max_grad_norm)
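# Usage note (an assumption, not from the original code): since `lr` and `step`
# are kept as device tensors, LR schedules should update them in place, e.g.
# `group['lr'].fill_(new_lr)`, so the LAMB kernel keeps receiving device tensors
# and no extra host/device synchronization is introduced.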
tensor_state = ['lr', 'step']
super(FusedMixedPrecisionLamb, self).__init__(params, defaults)
device = self.param_groups[0]['params'][0].device
for idx,group in enumerate(self.param_groups):
for item in tensor_state:
self.param_groups[idx][item] = group[item].to(device=device)
if multi_tensor_applier.available:
import amp_C
self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm_mp
# Skip buffer
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=device)
self.multi_tensor_lamb = amp_C.multi_tensor_lamb_mp
else:
raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')
# Mixed Precision support
self.reduced_precision_dtype = reduced_precision_dtype
self.param_groups_full_precision = []
self._step_supports_amp_scaling = True
self.adam_w_mode = 1 if adam_w_mode else 0
self.use_nvlamb = use_nvlamb
# This method is overridden from the parent class because there is no way to override
# the nested function cast() that copies a saved piece of state to the device without
# redundantly doing the copy.
def load_state_dict(self, state_dict):
r"""Loads the optimizer state.
Args:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`.
"""
# deepcopy, to be consistent with module API
state_dict = deepcopy(state_dict)
# Validate the state_dict
groups = self.param_groups
saved_groups = state_dict['param_groups']
if len(groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of "
"parameter groups")
param_lens = (len(g['params']) for g in groups)
saved_lens = (len(g['params']) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError("loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group")
# Update the state
id_map = {old_id: p for old_id, p in
zip(chain.from_iterable((g['params'] for g in saved_groups)),
chain.from_iterable((g['params'] for g in groups)))}
def cast(param, value):
r"""Make a deep copy of value, casting all tensors to device of param."""
if isinstance(value, torch.Tensor):
# The original version cast the saved value to the param's dtype
# This doesn't work for mixed precision Lamb where the momentum and
# velocity are expected to be in full precision while the params are
# in reduced precision
value = value.to(value.device)
return value
elif isinstance(value, dict):
return {k: cast(param, v) for k, v in value.items()}
elif isinstance(value, container_abcs.Iterable):
return type(value)(cast(param, v) for v in value)
else:
return value
# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state = defaultdict(dict)
for k, v in state_dict['state'].items():
if k in id_map:
param = id_map[k]
state[param] = cast(param, v)
else:
state[k] = v
# Update parameter groups, setting their 'params' value
def update_group(group, new_group):
new_group['params'] = group['params']
return new_group
param_groups = [
update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({'state': state, 'param_groups': param_groups})
def _setup_full_precision_params(self):
for i, pg in enumerate(self.param_groups):
param_list = pg['params']
self.param_groups_full_precision.append({
'params': [
p.clone().detach().to(dtype=torch.float32)
if (self.reduced_precision_dtype is not None) and (p.dtype == self.reduced_precision_dtype)
else None
for p in param_list
],
})
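# Entries stay None for params that are already stored in full precision; only
# reduced-precision params get a separate fp32 master copy here.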
# add_param_group() is overridden because default items can be tensors. The
# parent version does not clone the default item, so two param groups could
# accidentally share the same default tensor even though the values may need
# to differ between groups.
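# For example, with a shared tensor default, group_a['lr'].fill_(1e-4) would
# silently change group_b['lr'] as well; cloning below gives each group its own copy.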
def add_param_group(self, param_group):
super().add_param_group(param_group)
for name, default in self.defaults.items():
if isinstance(default, torch.Tensor):
self.param_groups[len(self.param_groups) - 1][name] = default.clone()
@torch.no_grad()
def step(self, closure=None, grad_scaler=None):
loss = None
if closure is not None:
loss = closure()
# The full precision params are set up in the first step of the optimizer
# instead of in the constructor because the full precision params will get
# out of sync with the model params if DDP syncs the model params across devices
# after the optimizer is constructed.
if len(self.param_groups_full_precision) == 0 :
self._setup_full_precision_params()
# create separate grad lists for params
grad_list = []
for gid,group in enumerate(self.param_groups):
for pid,p in enumerate(group['params']):
assert group['params'][0].dtype == p.dtype, \
"Error: Parameters are not of the identical type: {} != {}".format(
group['params'][0].dtype, p.dtype)
if p.grad is None:
continue
grad_list.append(p.grad)
# Overflow check of gradients
device = self.param_groups[0]["params"][0].device
found_inf = (
grad_scaler._check_inf_per_device(self)[device]
if grad_scaler is not None else torch.zeros((1,), device=device)
)
self._dummy_overflow_buf.copy_(found_inf)
# Get unscale scale factor
scale, inv_scale = None, None
if grad_scaler:
scale = grad_scaler._get_scale_async()
inv_scale = scale.double().reciprocal().float()
else:
scale = torch.ones((1,), device=device)
inv_scale = torch.ones((1,), device=device)
# grad_norm is of scaled gradients.
# So, multiply `max_grad_norm` by scale.
max_grad_norm = self.defaults['max_grad_norm'] * scale
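# e.g. with a loss scale of 2**14 and max_grad_norm = 1.0, the clipping
# threshold applied to the (still scaled) gradients becomes 16384, which
# corresponds to an unscaled gradient norm of 1.0.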
grad_norm = multi_tensor_applier(
self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[grad_list],
False,
)[0]
# Run LAMB optimization math
for gid, (group, group_full) in enumerate(zip(self.param_groups, self.param_groups_full_precision)):
bias_correction = 1 if group['bias_correction'] else 0
beta1, beta2 = group['betas']
grad_averaging = 1 if group['grad_averaging'] else 0
# assume same step across group now to simplify things
# per-parameter step can easily be supported by making it a tensor, or by passing a list into the kernel
group['step'] += (self._dummy_overflow_buf != 1).to(torch.int)
state_lists = [ [], # (0) grads
[], # (1) params
[], # (2) momentum state
[], # (3) velocity state
]
if self.reduced_precision_dtype is not None:
state_lists.append([]) # (4) params reduced_dtype
for p, p_full in zip(group['params'], group_full['params']):
if p.grad is None:
continue
assert not p.grad.is_sparse
state = self.state[p]
# State initialization
if len(state) == 0:
dtype = p.dtype
if self.reduced_precision_dtype is not None and p.dtype == self.reduced_precision_dtype :
dtype = torch.float32
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data, dtype=dtype)
# Exponential moving average of gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=dtype)
if self.reduced_precision_dtype is not None :
state_lists[0].append(p.grad.data)
state_lists[1].append(p_full.data)
state_lists[2].append(state['exp_avg'])
state_lists[3].append(state['exp_avg_sq'])
state_lists[4].append(p.data)
else :
state_lists[0].append(p.grad.data)
state_lists[1].append(p.data)
state_lists[2].append(state['exp_avg'])
state_lists[3].append(state['exp_avg_sq'])
multi_tensor_applier(
self.multi_tensor_lamb,
self._dummy_overflow_buf,
state_lists,
group['lr'],
beta1,
beta2,
group['eps'],
group['step'],
bias_correction,
group['weight_decay'],
grad_averaging,
self.adam_w_mode,
grad_norm,
max_grad_norm,
self.use_nvlamb,
found_inf,
inv_scale)
return loss
......@@ -2,4 +2,80 @@
`apex.transformer` is a module which enables efficient large Transformer models at scale.
`apex.transformer.tensor_parallel` is based on [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM)'s `megatron.mpu` module.
`apex.transformer.tensor_parallel` and `apex.transformer.pipeline_parallel` are both based on [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM)'s modules.
The former is based on `megatron.mpu`, the latter on `megatron.schedules` and `megatron.p2p_communication`.
## Tensor Model Parallel (TP)
APEX's tensor model parallel utilities provide `torch.nn.Module` implementations, custom fused kernels, and PRNG state handling.
See Appendix B.2 of [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) for the details of
PRNG state handling.
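As a rough illustration, the sketch below wires a column-parallel and a row-parallel linear layer together in the usual TP pattern for an MLP block. The layer names, constructor arguments, and the seeding helper are assumed from the Megatron-LM code this port is based on; check your apex version for the exact signatures.
```python
# A minimal sketch, assuming the script is launched with two processes per node
# (e.g. via torchrun) and that the Megatron-style names below exist in your apex build.
import os
import torch
import torch.distributed as dist
from apex.transformer import parallel_state, tensor_parallel

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
parallel_state.initialize_model_parallel(2)  # tensor_model_parallel_size = 2
tensor_parallel.model_parallel_cuda_manual_seed(1234)  # per-rank PRNG state handling

# The column-parallel GEMM keeps its output sharded (gather_output=False); the
# row-parallel GEMM consumes the shard directly and all-reduces the final result.
col = tensor_parallel.ColumnParallelLinear(1024, 4096, gather_output=False)
row = tensor_parallel.RowParallelLinear(4096, 1024, input_is_parallel=True)

x = torch.randn(8, 1024, device="cuda")
h, _ = col(x)  # Megatron-style parallel layers return (output, bias)
y, _ = row(h)
```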
## Pipeline Model Parallel (PP)
APEX's pipeline model parallel functions require models to have `.set_input_tensor` because
the input tensor for the `.forward` method can be `None`.
The following is a casual sketch of a training script with apex PP.
```python
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import get_forward_backward_func
class Model(nn.Module):
...
def __init__(self, *args, **kwargs):
super().__init__()
pre_process = kwargs.pop("pre_process")
post_process = kwargs.pop("post_process")
def set_input_tensor(self, tensor):
self.input_tensor = tensor
def forward(self, x, ...):
if parallel_state.is_pipeline_first_stage():
input = x
else:
input = self.input_tensor
...
def model_provider_func(*args, **kwargs):
return Model(*args, **kwargs)
def loss_func(pred, label):
loss = ...
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'nice_loss': averaged_loss}
def forward_step_func(batch, model):
input, label = process_batch(batch)
out = model(input)
return out, partial(loss_func, label)
forward_backward_func = get_forward_backward_func(virtual_pipeline_model_parallel_size, pipeline_model_parallel_size)
parallel_state.initialize_model_parallel(
tensor_model_parallel_size,
pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size,
)
# The following line basically is equivalent to `build_model(Model, wrap_with_ddp, virtual_pipeline_model_parallel_size, *model_args, **model_kwargs)`
model = build_model(model_provider_func, wrap_with_ddp, virtual_pipeline_model_parallel_size, *model_args, **model_kwargs)
optimizer = ...
data_loader = ...
for epoch in range(num_epochs):
for batch in data_loader:
forward_backward_func(forward_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape)
optimizer.step()
```
from . import tensor_parallel
from . import functional
from .enums import LayerType
from .enums import AttnType
from .enums import AttnMaskType
from .parallel_state import (
is_unitialized,
destroy_model_parallel,
get_data_parallel_group,
get_data_parallel_rank,
get_data_parallel_world_size,
get_embedding_group,
get_model_parallel_group,
get_tensor_model_parallel_group,
get_pipeline_model_parallel_group,
get_tensor_model_parallel_rank,
set_tensor_model_parallel_rank,
get_pipeline_model_parallel_rank,
set_pipeline_model_parallel_rank,
is_pipeline_first_stage,
is_pipeline_last_stage,
get_tensor_model_parallel_src_rank,
get_pipeline_model_parallel_first_rank,
get_pipeline_model_parallel_last_rank,
get_pipeline_model_parallel_next_rank,
get_pipeline_model_parallel_prev_rank,
get_tensor_model_parallel_world_size,
set_tensor_model_parallel_world_size,
get_pipeline_model_parallel_world_size,
set_pipeline_model_parallel_world_size,
get_virtual_pipeline_model_parallel_rank,
set_virtual_pipeline_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from apex.transformer import amp
from apex.transformer import functional
from apex.transformer import parallel_state
from apex.transformer import pipeline_parallel
from apex.transformer import tensor_parallel
from apex.transformer import utils
from apex.transformer.enums import LayerType
from apex.transformer.enums import AttnType
from apex.transformer.enums import AttnMaskType
__all__ = [
"amp",
"functional",
"parallel_state",
"pipeline_parallel",
"tensor_parallel",
"utils",
# enums.py
"LayerType",
"AttnType",
"AttnMaskType",
]
from apex.transformer._data._batchsampler import MegatronPretrainingRandomSampler
from apex.transformer._data._batchsampler import MegatronPretrainingSampler
__all__ = [
"MegatronPretrainingRandomSampler",
"MegatronPretrainingSampler",
]
"""BatchSampler implementations for POC of dynamic batch size or rampup_batch_size support.
Implementations are based on https://github.com/NVIDIA/Megatron-LM/blob/bcd605f8570ebeeb0436c115ebbfafc3c5a40ae5/megatron/data/data_samplers.py.
""" # NOQA
import abc
import torch
__all__ = [
"MegatronPretrainingSampler",
"MegatronPretrainingRandomSampler",
]
class _Base:
"""Base class for Megatron style BatchSampler."""
@abc.abstractmethod
def __len__(self) -> int:
...
@abc.abstractmethod
def __iter__(self):
...
@property
@abc.abstractmethod
def local_minibatch_size(self) -> int:
...
@local_minibatch_size.setter
@abc.abstractclassmethod
def local_minibatch_size(self) -> None:
...
class MegatronPretrainingSampler(_Base):
def __init__(
self,
total_samples: int,
consumed_samples: int,
local_minibatch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
drop_last: bool = True,
):
# Sanity checks.
if total_samples <= 0:
raise RuntimeError('no sample to consume: {}'.format(total_samples))
if consumed_samples >= total_samples:
raise RuntimeError('no samples left to consume: {}, {}'.format(consumed_samples, total_samples))
if local_minibatch_size <= 0:
raise RuntimeError(f"local minibatch size must be greater than 0: {local_minibatch_size}")
if data_parallel_size <= 0:
raise RuntimeError(f"data parallel size must be greater than 0: {data_parallel_size}")
if data_parallel_rank >= data_parallel_size:
raise RuntimeError('data_parallel_rank should be smaller than data parallel size: {}, {}'.format(data_parallel_rank, data_parallel_size))
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self._local_minibatch_size = local_minibatch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * data_parallel_size
self.drop_last = drop_last
def __len__(self):
return self.total_samples
def get_start_end_idx(self):
start_idx = self.data_parallel_rank * self.local_minibatch_size
end_idx = start_idx + self.local_minibatch_size
return start_idx, end_idx
@property
def local_minibatch_size(self) -> int:
return self._local_minibatch_size
@local_minibatch_size.setter
def local_minibatch_size(self, new_local_minibatch_size) -> None:
self._local_minibatch_size = new_local_minibatch_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * self.data_parallel_size
def __iter__(self):
batch = []
# The last (partial) batch is dropped unless drop_last is set to False.
for idx in range(self.consumed_samples, self.total_samples):
batch.append(idx)
if len(batch) == self.local_minibatch_times_data_parallel_size:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
batch = []
# Yield this rank's slice of the last partial batch if drop_last is False.
if len(batch) > 0 and not self.drop_last:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
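# Usage sketch (illustrative, not part of the upstream file): assuming 8 total
# samples, a local minibatch size of 2, and a data parallel size of 2, every
# global batch of 2 * 2 = 4 consecutive indices is split so that each data
# parallel rank receives a disjoint slice. The concrete numbers below are
# hypothetical.
#
#     sampler = MegatronPretrainingSampler(
#         total_samples=8,
#         consumed_samples=0,
#         local_minibatch_size=2,
#         data_parallel_rank=0,
#         data_parallel_size=2,
#     )
#     list(sampler)  # rank 0 -> [[0, 1], [4, 5]]; rank 1 would see [[2, 3], [6, 7]]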
class MegatronPretrainingRandomSampler(_Base):
"""Megatron style Random Batch Sampler.
The major difference from the Megatron-LM sampler is that `__iter__` yields a local minibatch, not a microbatch.
A local minibatch consists of `global_batch_size / data_parallel_size` samples.
Args:
total_samples: The number of data samples, i.e. ``len(dataset)``.
consumed_samples: The number of samples already consumed in pretraining.
local_minibatch_size: The number of samples in each batch returned from `__iter__`. Basically
`local_minibatch_size = global_batch_size / data_parallel_size`.
data_parallel_rank: The rank of this process in the data parallel group.
data_parallel_size: The number of data parallel ranks.
"""
def __init__(
self,
total_samples: int,
consumed_samples: int,
local_minibatch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
) -> None:
if total_samples <= 0:
raise ValueError(f"no sample to consume: total_samples of {total_samples}")
if local_minibatch_size <= 0:
raise ValueError(f"Invalid local_minibatch_size: {local_minibatch_size}")
if data_parallel_size <= 0:
raise ValueError(f"Invalid data_parallel_size: {data_parallel_size}")
if data_parallel_rank >= data_parallel_size:
raise ValueError(
f"data_parallel_rank should be smaller than data parallel size: {data_parallel_rank} < {data_parallel_size}"
)
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self._local_minibatch_size = local_minibatch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * self.data_parallel_size
self.last_batch_size = self.total_samples % self.local_minibatch_times_data_parallel_size
def __len__(self) -> int:
return self.total_samples
@property
def local_minibatch_size(self) -> int:
return self._local_minibatch_size
@local_minibatch_size.setter
def local_minibatch_size(self, new_local_minibatch_size) -> None:
self._local_minibatch_size = new_local_minibatch_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * self.data_parallel_size
def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
# note(mkozuki): might be better to uncomment
# assert current_epoch_samples % (self.data_parallel_size * apex.transformer.pipeline_parallel.utils.get_micro_batch_size()) == 0
# data sharding and random sampling
bucket_size = (self.total_samples // self.local_minibatch_times_data_parallel_size) * self.local_minibatch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
batch = []
# The last batch is dropped if it is not complete.
for idx in idx_range:
batch.append(idx)
if len(batch) == self.local_minibatch_size:
self.consumed_samples += self.local_minibatch_times_data_parallel_size
yield batch
batch = []
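# Usage sketch (illustrative, not part of the upstream file): the random sampler
# is intended to be passed to a torch DataLoader as `batch_sampler`, so that each
# data parallel rank draws shuffled local minibatches from its own bucket of the
# dataset. The toy dataset and the concrete sizes below are hypothetical.
#
#     import torch
#
#     dataset = torch.arange(100.0).unsqueeze(1)  # 100 samples, 1 feature each
#     sampler = MegatronPretrainingRandomSampler(
#         total_samples=len(dataset),
#         consumed_samples=0,
#         local_minibatch_size=4,
#         data_parallel_rank=0,
#         data_parallel_size=2,
#     )
#     loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
#     first = next(iter(loader))  # tensor of shape [4, 1]; indices come from rank 0's bucket only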