Unverified Commit 6c2babf9 authored by Burc Eryilmaz, committed by GitHub

Fuses dropout and softmax in backward pass, add bias support to CPP MHA, add additive mask support, separate Q/K/V parameters (#854)

Co-authored-by: Sukru Eryilmaz <seryilmaz@computelab-dgx1v-32.nvidia.com>
parent 36c9e904
#include <torch/extension.h>
#include <cuda_fp16.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(
bool is_training,
int heads,
torch::Tensor const& input,
const half* pad_mask,
float dropout_prob
);
torch::Tensor bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
float dropout_prob
);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool is_training,
int heads,
torch::Tensor const& input,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
}
return fwd_cuda(
is_training,
heads,
input,
use_mask ? static_cast<const half*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
}
torch::Tensor bwd(
bool use_mask,
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
softmax_results,
dropout_mask,
dropout_prob
);
}
} // end namespace additive_mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd, "Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd, "Self Multihead Attention masked softmax dropout -- Backward.");
}
#include <vector>
#include <iostream>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "softmax.h"
#include "dropout.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(
bool is_training,
int heads,
torch::Tensor const& input,
const half* pad_mask,
float dropout_prob
)
{
const int attn_batches = input.size(0);
const int sequences = attn_batches / heads;
const int q_seq_len = input.size(1);
const int k_seq_len = q_seq_len;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
auto act_options = input.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
} else {
softmax_success = dispatch_additive_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
}
if (is_training) {
// Use an at:: function so that the C++ version generates the same random mask as the Python version
auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
// Matmul2
return {
dropout_results,
dropout_mask,
softmax_results
};
}
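// ---------------------------------------------------------------------------
// Editorial sketch (not part of the original commit): an unfused ATen
// reference of what fwd_cuda above computes, useful for testing. The mask
// layout ([sequences, k_seq_len], additive values, broadcast over heads and
// query positions) is an assumption inferred from how the kernel indexes
// pad_mask with pad_batch_stride = heads * q_seq_len.
// ---------------------------------------------------------------------------
static std::vector<torch::Tensor> additive_mask_softmax_dropout_reference(
    bool is_training,
    int heads,
    torch::Tensor const& input,     // [attn_batches, q_seq_len, k_seq_len], half
    torch::Tensor const& pad_mask,  // [sequences, k_seq_len], half, additive
    float dropout_prob)
{
  const int64_t attn_batches = input.size(0);
  const int64_t sequences = attn_batches / heads;
  // Broadcast the per-sequence additive mask over heads and query positions.
  auto scores = input.view({sequences, heads, input.size(1), input.size(2)}) +
                pad_mask.view({sequences, 1, 1, pad_mask.size(1)});
  auto softmax_results = at::softmax(scores, -1).view_as(input);
  if (!is_training) {
    // Evaluation: no dropout is applied; the softmax output is the result.
    return {softmax_results, torch::Tensor(), softmax_results};
  }
  auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
  return {std::get<0>(dropout_tuple), std::get<1>(dropout_tuple), softmax_results};
}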
torch::Tensor bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
const int attn_batches = output_grads.size(0);
const int q_seq_len = output_grads.size(1);
const int k_seq_len = q_seq_len;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Output Tensor Allocations
// torch::Tensor input_grads = torch::empty_like(output_grads);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
//backward pass is completely in-place
return output_grads;
}
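// ---------------------------------------------------------------------------
// Editorial sketch (not part of the original commit): the math implemented by
// dispatch_masked_scale_softmax_backward above. With
//   g = dropout_mask * output_grads / (1 - p)
// the fused kernel computes, per row, the softmax backward
//   input_grads = softmax_results * (g - sum_k(g * softmax_results))
// and writes the result in place over output_grads. Unfused ATen equivalent:
// ---------------------------------------------------------------------------
static torch::Tensor masked_scale_softmax_backward_reference(
    torch::Tensor const& output_grads,     // [attn_batches, q_seq_len, k_seq_len], half
    torch::Tensor const& softmax_results,  // same shape, half
    torch::Tensor const& dropout_mask,     // same shape, uint8
    float dropout_prob)
{
  auto g = output_grads * dropout_mask.to(output_grads.scalar_type()) /
           (1.0 - dropout_prob);
  auto dot = (g * softmax_results).sum(-1, /*keepdim=*/true);
  return softmax_results * (g - dot);
}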
} // end namespace additive_mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(
bool is_training,
int heads,
torch::Tensor const& input,
const uint8_t* pad_mask,
float dropout_prob
);
torch::Tensor bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
const uint8_t *padding_mask,
float dropout_prob
);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool is_training,
int heads,
torch::Tensor const& input,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
}
return fwd_cuda(
is_training,
heads,
input,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
}
torch::Tensor bwd(
bool use_mask,
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
torch::Tensor const& padding_mask,
float dropout_prob
)
{
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
softmax_results,
dropout_mask,
use_mask ? static_cast<const uint8_t*>(padding_mask.data_ptr()) : nullptr,
dropout_prob
);
}
} // end namespace mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd, "Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd, "Self Multihead Attention masked softmax dropout -- Backward.");
}
#include <vector>
#include <iostream>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "softmax.h"
#include "dropout.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(
bool is_training,
int heads,
torch::Tensor const& input,
const uint8_t* pad_mask,
float dropout_prob
)
{
const int attn_batches = input.size(0);
const int sequences = attn_batches / heads;
const int q_seq_len = input.size(1);
const int k_seq_len = q_seq_len;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
auto act_options = input.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
}
if (is_training) {
// Use an at:: function so that the C++ version generates the same random mask as the Python version
auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
// Matmul2
return {
dropout_results,
dropout_mask,
softmax_results
};
}
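// ---------------------------------------------------------------------------
// Editorial sketch (not part of the original commit): unfused ATen reference
// for the byte-mask softmax above (the dropout step is identical to the
// additive variant). The mask layout ([sequences, k_seq_len], nonzero = padded
// key position) is an assumption inferred from how dispatch_masked_softmax is
// called with pad_batch_stride = heads * q_seq_len.
// ---------------------------------------------------------------------------
static torch::Tensor masked_softmax_reference(
    int heads,
    torch::Tensor const& input,     // [attn_batches, q_seq_len, k_seq_len], half
    torch::Tensor const& pad_mask)  // [sequences, k_seq_len], uint8
{
  const int64_t sequences = input.size(0) / heads;
  auto scores = input.view({sequences, heads, input.size(1), input.size(2)});
  // Padded key positions are excluded by setting them to -inf before softmax.
  auto masked = scores.masked_fill(
      pad_mask.view({sequences, 1, 1, pad_mask.size(1)}).to(at::kBool),
      -INFINITY);
  return at::softmax(masked, -1).view_as(input);
}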
torch::Tensor bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
const uint8_t *padding_mask,
float dropout_prob
)
{
const int attn_batches = output_grads.size(0);
const int q_seq_len = output_grads.size(1);
const int k_seq_len = q_seq_len;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Output Tensor Allocations
// torch::Tensor input_grads = torch::empty_like(output_grads);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
if (padding_mask == nullptr) {
dispatch_masked_scale_softmax_backward<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
} else{
dispatch_masked_scale_softmax_backward_masked_out<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
static_cast<uint8_t const*>(padding_mask),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
heads);
}
//backward pass is completely in-place
return output_grads;
}
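// Editorial note (assumption, not stated in the original commit): the
// *_masked_out variant above also takes the forward padding mask so that
// gradient elements at padded key positions (where the forward softmax output
// was forced to zero) can be written as exact zeros; `heads` lets the kernel
// map each of the attn_batches * q_seq_len rows back to its sequence's mask row.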
} // end namespace mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace self_bias {
namespace cublas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases,
torch::Tensor const& output_biases,
const uint8_t* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
//torch::Tensor const& input_biases,
//torch::Tensor const& output_biases,
torch::Tensor const& dropout_mask,
float dropout_prob
);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs, torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases, torch::Tensor const& output_biases,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
}
return fwd_cuda(
use_time_mask,
is_training,
heads,
inputs,
input_weights,
output_weights,
input_biases,
output_biases,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
}
std::vector<torch::Tensor> bwd(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob
);
}
} // end namespace cublas_gemmex
} // end namespace self_bias
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
#include <torch/extension.h>
#include <vector>
#include <cuda_fp16.h>
namespace multihead_attn {
namespace self_bias_additive_mask {
namespace cublas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases,
torch::Tensor const& output_biases,
const half* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
//torch::Tensor const& input_biases,
//torch::Tensor const& output_biases,
torch::Tensor const& dropout_mask,
float dropout_prob
);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs, torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases, torch::Tensor const& output_biases,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
}
return fwd_cuda(
use_time_mask,
is_training,
heads,
inputs,
input_weights,
output_weights,
input_biases,
output_biases,
use_mask ? static_cast<const half*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
}
std::vector<torch::Tensor> bwd(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob
);
}
} // end namespace cublas_gemmex
} // end namespace self_bias_additive_mask
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
#include <vector>
#include <iostream>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h"
#include "layer_norm.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
namespace multihead_attn {
namespace self_bias_additive_mask {
namespace cublas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases,
torch::Tensor const& output_biases,
const half* pad_mask,
float dropout_prob
)
{
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
const int k_seq_len = q_seq_len;
const int batches = sequences * q_seq_len;
const int head_dim = embed_dim / heads;
const int output_lin_dim = 3 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
auto act_options = inputs.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options);
// Input Linear Results Pointers to Q, K, and V of interleaved activations
void* q_lin_results_ptr = static_cast<void*>(input_lin_results.data_ptr());
void* k_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);
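// Editorial note (layout as implied by the strides below): within the last
// dimension of input_lin_results, the Q, K and V blocks of each head
// (head_dim elements each) are contiguous, so the three pointers above differ
// only by head_dim offsets. The strided-batched GEMMs walk the heads with
// batch_stride = 3 * head_dim inside rows of length
// lead_dim = attn_batches * 3 * head_dim.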
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
input_lin_results.copy_(input_biases);
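// Editorial note: copy_ broadcasts the [3 * embed_dim] bias across every
// (q_seq_len, sequences) position; the GEMM below then accumulates onto it
// (beta_one), so the bias addition needs no extra kernel.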
THCublasCheck(cublasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(&beta_one),
q_lin_results_ptr,
CUDA_R_16F,
output_lin_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state,
a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
beta_zero,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
} else {
softmax_success = dispatch_additive_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
}
if (is_training) {
// Use an at:: function so that the C++ version generates the same random mask as the Python version
auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
// Matmul2
gemm_switch_fp32accum( state,
a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
k_seq_len,
k_seq_len*q_seq_len,
beta_zero,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches);
outputs.copy_(output_biases);
// Output Linear
THCublasCheck(cublasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()),
CUDA_R_16F,
embed_dim,
CUDA_R_32F,
//CUBLAS_GEMM_ALGO1_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_lin_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs
};
}
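// ---------------------------------------------------------------------------
// Editorial sketch (not part of the original commit): the same forward pass
// expressed with unfused ATen ops, handy as a correctness reference. The
// per-head interleaved [Q|K|V] weight layout and the [sequences, k_seq_len]
// additive mask layout are assumptions matching the GEMM strides used above.
// ---------------------------------------------------------------------------
static torch::Tensor self_attn_forward_reference(
    int heads,
    torch::Tensor const& inputs,          // [q_seq_len, sequences, embed_dim], half
    torch::Tensor const& input_weights,   // [3 * embed_dim, embed_dim]
    torch::Tensor const& output_weights,  // [embed_dim, embed_dim]
    torch::Tensor const& input_biases,
    torch::Tensor const& output_biases,
    torch::Tensor const& pad_mask,        // [sequences, k_seq_len], additive, half
    float dropout_prob,
    bool is_training)
{
  const int64_t q_seq_len = inputs.size(0);
  const int64_t sequences = inputs.size(1);
  const int64_t embed_dim = inputs.size(2);
  const int64_t head_dim = embed_dim / heads;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
  // Input projection; split the interleaved per-head [Q|K|V] blocks.
  auto qkv = at::linear(inputs, input_weights, input_biases)
                 .view({q_seq_len, sequences, heads, 3, head_dim});
  auto q = qkv.select(3, 0).permute({1, 2, 0, 3});  // [sequences, heads, q, head_dim]
  auto k = qkv.select(3, 1).permute({1, 2, 0, 3});
  auto v = qkv.select(3, 2).permute({1, 2, 0, 3});
  // Scaled dot-product attention with an additive mask broadcast over heads and queries.
  auto scores = at::matmul(q, k.transpose(-2, -1)) * scale +
                pad_mask.view({sequences, 1, 1, pad_mask.size(1)});
  auto probs = at::softmax(scores, -1);
  if (is_training) {
    probs = std::get<0>(at::_fused_dropout(probs, 1.0 - dropout_prob));
  }
  auto context = at::matmul(probs, v)            // [sequences, heads, q, head_dim]
                     .permute({2, 0, 1, 3})
                     .reshape({q_seq_len, sequences, embed_dim});
  // Output projection.
  return at::linear(context, output_weights, output_biases);
}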
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
const int k_seq_len = q_seq_len;
const int batches = sequences * q_seq_len;
const int head_dim = embed_dim / heads;
const int output_lin_dim = 3 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Output Tensor Allocations
torch::Tensor input_grads = torch::empty_like(inputs);
torch::Tensor input_weight_grads = torch::empty_like(input_weights);
torch::Tensor output_weight_grads = torch::empty_like(output_weights);
// Intermediate Tensor Allocations
at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());
auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim;
auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
char b_layout_n{'n'};
char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
CUDA_R_16F,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
CUDA_R_16F,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
// MatMul2 Dgrad1
gemm_switch_fp32accum( state,
a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// Matmul2 Dgrad2
gemm_switch_fp32accum( state,
a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward<half, half, float,false>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
// Matmul1 Dgrad1
gemm_switch_fp32accum( state,
a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
// Matmul1 Dgrad2
gemm_switch_fp32accum( state,
a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
// Input Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
//static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
CUDA_R_16F,
embed_dim,
CUDA_R_32F,
//CUBLAS_GEMM_ALGO10_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
CUDA_R_16F,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_grads,
input_weight_grads,
output_weight_grads,
input_bias_grads,
output_bias_grads
};
}
} // end namespace cublas_gemmex
} // end namespace self_bias_additive_mask
} // end namespace multihead_attn
#include <vector>
#include <iostream>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h"
#include "layer_norm.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
namespace multihead_attn {
namespace self_bias {
namespace cublas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases,
torch::Tensor const& output_biases,
const uint8_t* pad_mask,
float dropout_prob
)
{
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
const int k_seq_len = q_seq_len;
const int batches = sequences * q_seq_len;
const int head_dim = embed_dim / heads;
const int output_lin_dim = 3 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
auto act_options = inputs.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options);
// Input Linear Results Pointers to Q, K, and V of interleaved activations
void* q_lin_results_ptr = static_cast<void*>(input_lin_results.data_ptr());
void* k_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
input_lin_results.copy_(input_biases);
THCublasCheck(cublasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(&beta_one),
q_lin_results_ptr,
CUDA_R_16F,
output_lin_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state,
a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
beta_zero,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
} else {
if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
}
}
if (is_training) {
// Use an at:: function so that the C++ version generates the same random mask as the Python version
auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
// Matmul2
gemm_switch_fp32accum( state,
a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
k_seq_len,
k_seq_len*q_seq_len,
beta_zero,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches);
outputs.copy_(output_biases);
// Output Linear
THCublasCheck(cublasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()),
CUDA_R_16F,
embed_dim,
CUDA_R_32F,
//CUBLAS_GEMM_ALGO1_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_lin_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs
};
}
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
const int k_seq_len = q_seq_len;
const int batches = sequences * q_seq_len;
const int head_dim = embed_dim / heads;
const int output_lin_dim = 3 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Output Tensor Allocations
torch::Tensor input_grads = torch::empty_like(inputs);
torch::Tensor input_weight_grads = torch::empty_like(input_weights);
torch::Tensor output_weight_grads = torch::empty_like(output_weights);
// Intermediate Tensor Allocations
at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());
auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim;
auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
char b_layout_n{'n'};
char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
CUDA_R_16F,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
CUDA_R_16F,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
// MatMul2 Dgrad1
gemm_switch_fp32accum( state,
a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// Matmul2 Dgrad2
gemm_switch_fp32accum( state,
a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward<half, half, float,false>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
// Matmul1 Dgrad1
gemm_switch_fp32accum( state,
a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
// Matmul1 Dgrad2
gemm_switch_fp32accum( state,
a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
// Input Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
//static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
CUDA_R_16F,
embed_dim,
CUDA_R_32F,
//CUBLAS_GEMM_ALGO10_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
CUDA_R_16F,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_grads,
input_weight_grads,
output_weight_grads,
input_bias_grads,
output_bias_grads
};
}
} // end namespace cublas_gemmex
} // end namespace self_bias
} // end namespace multihead_attn
@@ -246,6 +246,230 @@ bool dispatch_softmax(output_t *dst, const input_t *src, int softmax_elements, i
return false;
}
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG=1>
__global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride)
{
assert(ELEMENTS_PER_LDG_STG==1);
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
src += thread_offset;
dst += thread_offset;
// load data from global memory
input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0;i < WARP_BATCH;++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int pad_thread_offset = ( (first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;
const half* curr_mask = pad_mask + pad_thread_offset;
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
//masking_value is a large negative value
elements_input[i][it + element] = -10000;
}
if (element_index < batch_element_count) {
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);
//apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
// (__half)-std::numeric_limits<float>::infinity(),
// curr_mask + itr_jmp);
elements_input[i][it] += *(curr_mask + itr_jmp);
}
}
}
// convert input_t to acc_t
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
elements[i][it] = elements_input[i][it];
}
}
constexpr uint32_t FULL_MASK = 0xffffffff;
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1;it < WARP_ITERATIONS;++it) {
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
for (int it = 0;it < WARP_ITERATIONS;++it) {
//elements[i][it] = expf(elements[i][it] - max_value[i]);
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
// store result
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0;it < WARP_ITERATIONS;it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
//dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
output_t out[ELEMENTS_PER_LDG_STG];
for (int element = 0;element < ELEMENTS_PER_LDG_STG;++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
}
else {
break;
}
}
}
}
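// Editorial note: each thread owns WARP_ITERATIONS elements of a row, strided
// by WARP_SIZE; out-of-range elements are initialized to -10000 so that, after
// exponentiation, they contribute (almost) nothing to the row sum. The additive
// mask is applied by simply adding the corresponding mask value to each loaded
// element before the max and sum reductions.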
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t>
using additive_masked_softmax_forward_func = void(*)(input_t *dst, const output_t *src, const half *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride);
template <typename input_t, typename output_t, typename acc_t>
bool warp_additive_masked_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp, additive_masked_softmax_forward_func<input_t, output_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
switch (log2_elements) {
case 0: // 1
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,1,1>;
break;
case 1: // 2
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,2,1>;
break;
case 2: // 4
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,4,1>;
break;
case 3: // 8
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,8,1>;
break;
case 4: // 16
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,16,1>;
break;
case 5: // 32
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,1,32,1>;
break;
case 6: // 64
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,2,32,1>;
break;
case 7: // 128
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;
break;
case 8: // 256
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,8,32,1>;
break;
case 9: // 512
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,16,32,1>;
break;
case 10: // 1024
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,32,32,1>;
break;
default:
return false;
}
return true;
}
template<typename input_t, typename output_t, typename acc_t>
bool dispatch_additive_masked_softmax(output_t *dst, const input_t *src, const input_t *pad_mask, int softmax_elements, int softmax_elements_stride, int batch_count, int pad_batch_stride)
{
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 1024) {
// compute function index. there's a function for each power of two size up to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements) ++log2_elements;
additive_masked_softmax_forward_func<input_t, output_t> kernel;
int warp_size, batches_per_warp;
if (!warp_additive_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
// use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
// compute warps per block.
int warps_per_block = (threads_per_block / warp_size);
// compute launch size
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);
return true;
}
return false;
}
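// Editorial note: typical host-side call, mirroring fwd_cuda earlier in this
// commit. Shapes assumed: scores [attn_batches, q_seq_len, k_seq_len] (half,
// contiguous, updated in place) and mask [sequences, k_seq_len] (half, additive):
//
//   dispatch_additive_masked_softmax<half, half, float>(
//       scores_ptr, scores_ptr, mask_ptr,
//       k_seq_len,                            // softmax_elements
//       k_seq_len,                            // softmax_elements_stride
//       attn_batches * q_seq_len,             // batch_count (rows)
//       attn_batches * q_seq_len / sequences  // pad_batch_stride (rows per mask row)
//   );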
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
@@ -693,6 +917,511 @@ bool dispatch_time_masked_softmax(output_t *dst, const input_t *src, const uint8
return false;
}
int log2_ceil_native(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE>
__device__ __forceinline__ void warp_reduce_sum(acc_t* sum) {
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = sum[i] + b;
}
}
}
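// Butterfly (XOR) reduction: with WARP_SIZE = 32 the offsets are 16, 8, 4, 2, 1, and after
// the last step every lane of the warp holds the complete sum for each of its WARP_BATCH
// rows, so no separate broadcast from lane 0 is needed.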
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Warp softmax backward functions as fused variants of at::softmax_backward_data function
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// The softmax backward-data function is taken from native PyTorch; the elementwise multiply is fused into the epilogue, together with the dropout masking and scaling.
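// In math terms (a sketch of what the fused epilogue computes, non-log-softmax path): for
// y = softmax(x) and incoming gradient dy, softmax backward is dx_i = y_i * (g_i - sum_j g_j * y_j)
// with g = dropout_mask * scale * dy (scale is typically 1/(1-p)). grad_reg below already holds
// g * y, so the store loop only needs grad_reg[i][it] - output_reg[i][it] * sum[i].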
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>
__global__ void masked_scale_softmax_warp_backward_masked_dgrad(output_t *gradInput, const input_t *grad, const input_t *output, const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, int batch_size, int stride, int element_count, int heads)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x % WARP_SIZE;
// the first element to process by the current thread
int thread_offset = first_batch * stride + local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
mask += thread_offset;
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
// but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
// the nested loops.
// This should have no impact on performance because the loops are unrolled anyway.
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] ;
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] ;
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
grad_reg[i][it] = (input_t)((acc_t)mask[i*element_count+it*WARP_SIZE] * (acc_t)grad[i*element_count+it*WARP_SIZE] * (acc_t)scale )*output[i*element_count+it*WARP_SIZE];
output_reg[i][it] = output[i*element_count+it*WARP_SIZE];
} else {
grad_reg[i][it] = acc_t(0);
output_reg[i][it] = acc_t(0);
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
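            // Assumed layout (matches the index math below): rows are flattened as
            // [sequence, head, query, key] and the byte padding mask is stored per
            // [sequence, key], so one mask row is shared by every head and every query
            // position of that sequence; padded keys receive a zero gradient.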
int total_ind = thread_offset + i*element_count + it*WARP_SIZE;
int pad_mask_ind = element_count*(total_ind/(heads * element_count * element_count)) + total_ind%element_count;
uint8_t pad_mask_element = 1 - pad_mask[pad_mask_ind];
if (pad_mask_element == 0) gradInput[i*element_count+it*WARP_SIZE] = 0;
else {
if (is_log_softmax) {
gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
} else {
gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);
}
}
}
}
}
}
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_masked_scale_softmax_backward_masked_out(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int batch_count, int heads)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil_native(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
        // use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 0, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 1: // 2
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 1, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 2: // 4
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 2, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 3: // 8
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 3, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 4: // 16
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 4, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 5: // 32
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 5, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 6: // 64
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 6, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 7: // 128
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 7, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 8: // 256
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 8, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 9: // 512
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 9, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
case 10: // 1024
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 10, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride, softmax_elements, heads);
break;
default:
break;
}
}
}
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>
__global__ void masked_scale_softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output, const uint8_t *mask, acc_t scale, int batch_size, int stride, int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x % WARP_SIZE;
// the first element to process by the current thread
int thread_offset = first_batch * stride + local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
mask += thread_offset;
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
// but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
// the nested loops.
// This should have no impact on performance because the loops are unrolled anyway.
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] ;
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] ;
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
grad_reg[i][it] = (input_t)((acc_t)mask[i*element_count+it*WARP_SIZE] * (acc_t)grad[i*element_count+it*WARP_SIZE] * (acc_t)scale )*output[i*element_count+it*WARP_SIZE];
output_reg[i][it] = output[i*element_count+it*WARP_SIZE];
} else {
grad_reg[i][it] = acc_t(0);
output_reg[i][it] = acc_t(0);
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
if (is_log_softmax) {
gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
} else {
gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);
}
}
}
}
}
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_masked_scale_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output, const uint8_t *mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int batch_count)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil_native(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
        // use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 0, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 1, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 2, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 3, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 4, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 5, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 6, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 7, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 8, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 9, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 10, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}
// elementwise multiplication called in at::softmax_backward_data is fused inside softmax dgrad kernel
// as a result of fusion, intermediate multiplication result is stored in fp32 in registers, instead of fp16
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>
__global__ void softmax_warp_backward_fused_native(output_t *gradInput, const input_t *grad, const input_t *output, int batch_size, int stride, int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x % WARP_SIZE;
// the first element to process by the current thread
int thread_offset = first_batch * stride + local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
// but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
// the nested loops.
// This should have no impact on performance because the loops are unrolled anyway.
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] ;
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] ;
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
grad_reg[i][it] = grad[i*element_count+it*WARP_SIZE]*output[i*element_count+it*WARP_SIZE];
output_reg[i][it] = output[i*element_count+it*WARP_SIZE];
} else {
grad_reg[i][it] = acc_t(0);
output_reg[i][it] = acc_t(0);
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0]; //* output_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];// * output_reg[i][it];
}
}
warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
if (is_log_softmax) {
gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
} else {
gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);
}
}
}
}
}
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_softmax_backward_fused_native(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil_native(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
        // use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 0, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 1, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 2, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 3, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 4, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 5, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 6, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 7, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 8, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 9, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 10, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Warp softmax backward
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
from .self_multihead_attn import SelfMultiheadAttn
from .encdec_multihead_attn import EncdecMultiheadAttn
from .mask_softmax_dropout_func import fast_mask_softmax_dropout_func
import torch
import fast_self_multihead_attn
import fast_self_multihead_attn_bias
import fast_self_multihead_attn_bias_additive_mask
class FastSelfAttnFunc(torch.autograd.Function) :
    @staticmethod
    def forward(ctx, use_time_mask, is_training, heads, inputs, input_weights, output_weights, input_biases, output_biases, pad_mask, mask_additive, dropout_prob):
        use_biases_t = torch.tensor([input_biases is not None])
        heads_t = torch.tensor([heads])
        dropout_prob_t = torch.tensor([dropout_prob])
        null_tensor = torch.tensor([])
        use_mask = (pad_mask is not None)
if use_biases_t[0]:
if not mask_additive:
input_lin_results, \
softmax_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
outputs = \
fast_self_multihead_attn_bias.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs, \
input_weights, \
output_weights, \
input_biases, \
output_biases, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
else:
input_lin_results, \
softmax_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
outputs = \
fast_self_multihead_attn_bias_additive_mask.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs, \
input_weights, \
output_weights, \
input_biases, \
output_biases, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
else:
            input_lin_results, \
            softmax_results, \
            dropout_results, \
@@ -27,7 +71,8 @@ class FastSelfAttnFunc(torch.autograd.Function) :
                pad_mask if use_mask else null_tensor, \
                dropout_prob)
        ctx.save_for_backward(use_biases_t, \
                              heads_t, \
                              matmul2_results, \
                              dropout_results, \
                              softmax_results, \
@@ -38,10 +83,12 @@ class FastSelfAttnFunc(torch.autograd.Function) :
                              dropout_mask, \
                              dropout_prob_t)
        return outputs.detach()
    @staticmethod
    def backward(ctx, output_grads):
        use_biases_t, \
        heads_t, \
        matmul2_results, \
        dropout_results, \
@@ -53,6 +100,28 @@ class FastSelfAttnFunc(torch.autograd.Function) :
        dropout_mask, \
        dropout_prob_t = ctx.saved_tensors
if use_biases_t[0]:
input_grads, \
input_weight_grads, \
output_weight_grads, \
input_bias_grads, \
output_bias_grads = \
fast_self_multihead_attn_bias.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
else:
input_bias_grads = None
output_bias_grads = None
            input_grads, \
            input_weight_grads, \
            output_weight_grads = \
@@ -68,7 +137,6 @@ class FastSelfAttnFunc(torch.autograd.Function) :
                    output_weights, \
                    dropout_mask, \
                    dropout_prob_t[0])
        return None, None, None, input_grads, input_weight_grads, output_weight_grads, input_bias_grads, output_bias_grads, None, None, None
fast_self_attn_func = FastSelfAttnFunc.apply
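# Illustrative call sketch (not part of the module; tensor shapes and variable names are
# assumed): with fp16 CUDA tensors `inputs` of shape (seq_len, batch, embed_dim), packed
# in-projection weights/biases and an optional padding mask, the bias-enabled path is
#
#   out = fast_self_attn_func(False,           # use_time_mask
#                             True,            # is_training
#                             16,              # heads
#                             inputs,
#                             in_proj_weight, out_proj_weight,
#                             in_proj_bias, out_proj_bias,
#                             pad_mask,        # or None
#                             False,           # mask_additive
#                             0.1)             # dropout_prob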
import torch
import fast_mask_softmax_dropout
import fast_additive_mask_softmax_dropout
class MaskSoftmaxDropout(torch.autograd.Function) :
@staticmethod
def forward(ctx, is_training, heads, inputs, pad_mask, mask_additive, dropout_prob):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
use_mask_t = torch.tensor([use_mask])
mask_additive_t = torch.tensor([mask_additive])
if mask_additive:
dropout_results, \
dropout_mask, \
softmax_results = \
fast_additive_mask_softmax_dropout.forward( \
use_mask, \
is_training, \
heads, \
inputs, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
else:
dropout_results, \
dropout_mask, \
softmax_results = \
fast_mask_softmax_dropout.forward( \
use_mask, \
is_training, \
heads, \
inputs, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(
use_mask_t, \
heads_t, \
softmax_results, \
dropout_mask, \
pad_mask if use_mask else null_tensor, \
mask_additive_t, \
dropout_prob_t)
return dropout_results.detach()
@staticmethod
def backward(ctx, output_grads):
use_mask_t, \
heads_t, \
softmax_results, \
dropout_mask, \
pad_mask, \
mask_additive_t, \
dropout_prob_t = ctx.saved_tensors
if mask_additive_t[0]:
input_grads = \
fast_additive_mask_softmax_dropout.backward( \
use_mask_t[0], \
heads_t[0], \
output_grads, \
softmax_results, \
dropout_mask, \
dropout_prob_t[0])
else:
input_grads = \
fast_mask_softmax_dropout.backward( \
use_mask_t[0], \
heads_t[0], \
output_grads, \
softmax_results, \
dropout_mask, \
pad_mask, \
dropout_prob_t[0])
return None, None, input_grads, None, None, None
fast_mask_softmax_dropout_func = MaskSoftmaxDropout.apply
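# Illustrative call sketch (assumed shapes, mirroring the unit test further below): `inputs`
# is an fp16 CUDA tensor of shape (sequences*heads, seq_len, seq_len) holding the attention
# scores, and `mask` is a half-precision additive mask of shape (sequences, seq_len):
#
#   probs = fast_mask_softmax_dropout_func(True,   # is_training
#                                          heads,
#                                          inputs,
#                                          mask,   # or None
#                                          True,   # mask_additive
#                                          0.1)    # dropout_prob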
@@ -21,7 +21,7 @@ class SelfMultiheadAttn(nn.Module):
    See "Attention Is All You Need" for more details.
    """
    def __init__(self, embed_dim, num_heads, dropout=0., bias=False, include_norm_add=False, impl='fast', separate_qkv_params=False, mask_additive=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
@@ -32,17 +32,38 @@ class SelfMultiheadAttn(nn.Module):
        self.include_norm_add = include_norm_add
        self.impl = impl
        self.scaling = self.head_dim**-0.5
self.separate_qkv_params = separate_qkv_params
self.mask_additive = mask_additive
if mask_additive:
assert self.include_norm_add == False, "additive mask not supported with layer norm"
assert impl == 'default' or (impl == 'fast' and bias), "additive mask not supported for fast mode without bias"
if separate_qkv_params:
self.q_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
self.k_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
self.v_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
else:
            self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
        self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
        if self.bias:
            if separate_qkv_params:
                self.q_bias = Parameter(torch.Tensor(embed_dim))
                self.k_bias = Parameter(torch.Tensor(embed_dim))
                self.v_bias = Parameter(torch.Tensor(embed_dim))
            else:
                self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
            self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
else:
if separate_qkv_params:
self.register_parameter('q_bias', None)
self.register_parameter('k_bias', None)
self.register_parameter('v_bias', None)
self.q_bias = None
self.k_bias = None
self.v_bias = None
            else:
                self.register_parameter('in_proj_bias', None)
                self.in_proj_bias = None
            self.register_parameter('out_proj_bias', None)
            self.out_proj_bias = None
        if self.include_norm_add:
            if impl == 'fast':
@@ -67,9 +88,19 @@ class SelfMultiheadAttn(nn.Module):
        else : assert False, "Unsupported impl: {} !".format(impl)
    def reset_parameters(self):
if self.separate_qkv_params:
nn.init.xavier_uniform_(self.q_weight)
nn.init.xavier_uniform_(self.k_weight)
nn.init.xavier_uniform_(self.v_weight)
else:
            nn.init.xavier_uniform_(self.in_proj_weight)
        nn.init.xavier_uniform_(self.out_proj_weight)
        if self.bias:
if self.separate_qkv_params:
nn.init.constant_(self.q_bias, 0.)
nn.init.constant_(self.k_bias, 0.)
nn.init.constant_(self.v_bias, 0.)
else:
                nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj_bias, 0.)
        if self.include_norm_add:
@@ -88,10 +119,22 @@ class SelfMultiheadAttn(nn.Module):
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """
if self.separate_qkv_params:
input_weights = torch.cat([self.q_weight.view(self.num_heads,1,self.head_dim,self.embed_dim), self.k_weight.view(self.num_heads,1,self.head_dim,self.embed_dim), self.v_weight.view(self.num_heads,1,self.head_dim,self.embed_dim)], dim=1).reshape(3*self.embed_dim,self.embed_dim).contiguous()
else:
input_weights = self.in_proj_weight
if self.bias:
if self.separate_qkv_params:
input_bias = torch.cat([self.q_bias.view(self.num_heads,1,self.head_dim), self.k_bias.view(self.num_heads,1,self.head_dim), self.v_bias.view(self.num_heads,1,self.head_dim)],dim=1).reshape(3*self.embed_dim).contiguous()
else:
input_bias = self.in_proj_bias
else:
input_bias=None
        if key_padding_mask is not None:
            assert (attn_mask is None), "ERROR attn_mask and key_padding_mask should not be both defined!"
            mask = key_padding_mask
        elif attn_mask is not None:
assert self.mask_additive == False, "additive mask not supported for time mask"
            mask = attn_mask
        else:
            mask = None
@@ -100,12 +143,12 @@ class SelfMultiheadAttn(nn.Module):
            if self.impl == 'fast':
                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
                                         self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,
                                         input_weights, self.out_proj_weight, mask, self.dropout)
            else:
                lyr_nrm_results = self.lyr_nrm(query)
                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results,
                                         input_weights, self.out_proj_weight,
                                         input_bias, self.out_proj_bias,
                                         mask, self.dropout)
            if is_training:
                outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
@@ -114,11 +157,11 @@ class SelfMultiheadAttn(nn.Module):
        else:
            if self.impl == 'fast':
                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
                                         input_weights, self.out_proj_weight, input_bias, self.out_proj_bias, mask, self.mask_additive, self.dropout)
            else:
                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, query,
                                         input_weights, self.out_proj_weight,
                                         input_bias, self.out_proj_bias,
                                         mask, self.mask_additive, self.dropout)
        return outputs, None
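# Hedged usage sketch (illustrative only; the forward keyword names are assumed from the
# code above rather than taken from a documented signature):
#
#   attn = SelfMultiheadAttn(1024, 16, dropout=0.1, bias=True, impl='fast',
#                            separate_qkv_params=True, mask_additive=True).cuda().half()
#   attn.reset_parameters()
#   # additive key-padding mask: 0 keeps a position, a large negative value masks it out
#   pad_mask = (torch.randn(batch, seq_len, device='cuda') > 0).half() * -10000
#   out, _ = attn(query, query, query, key_padding_mask=pad_mask, is_training=True)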
import torch
import unittest
import torch.nn.functional as F
from apex.contrib.multihead_attn import fast_mask_softmax_dropout_func
class FusedSoftmaxTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.seq_length = 80
self.sequences = 10
self.hidden_dim = 1024
self.heads = 16
self.dropout_prob = 0.0
self.mask = (torch.randn(self.sequences,self.seq_length)>0).cuda()
self.mask = self.mask.half()*-10000
self.ref_inputs = torch.randn(self.heads * self.sequences, self.seq_length, self.seq_length,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
self.tst_inputs = self.ref_inputs.clone().detach().requires_grad_(True)
def test_fused_softmax(self) :
grads = torch.randn_like(self.tst_inputs)
y_ref = self.ref_inputs.view(self.sequences, self.heads, self.seq_length, self.seq_length)
y_ref = y_ref + self.mask.unsqueeze(1).unsqueeze(2)
y_ref = y_ref.view(self.sequences*self.heads, self.seq_length, self.seq_length)
y_ref = F.softmax(y_ref, dim=-1)
y_ref = torch._fused_dropout(y_ref, 1.0)
y_tst = fast_mask_softmax_dropout_func(True, self.heads, self.tst_inputs, self.mask, True, 0.0)
y_ref[0].backward(grads)
y_tst.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(y_ref[0], y_tst, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))
if __name__ == '__main__':
unittest.main()
@@ -243,6 +243,58 @@ if "--fast_multihead_attn" in sys.argv:
        raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
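        # The extensions appended below are built only when the install is invoked with the
        # --fast_multihead_attn option and nvcc is available; they target sm_70 (Volta) and
        # provide the fused softmax/dropout and bias-enabled self-attention kernels.
        # One typical invocation, assuming the usual apex install pattern, would be:
        #   pip install -v --no-cache-dir --global-option="--fast_multihead_attn" ./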
ext_modules.append(
CUDAExtension(name='fast_additive_mask_softmax_dropout',
sources=['apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout.cpp',
'apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag}))
ext_modules.append(
CUDAExtension(name='fast_mask_softmax_dropout',
sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp',
'apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag}))
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_bias_additive_mask',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag}))
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_bias',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag}))
        ext_modules.append(
            CUDAExtension(name='fast_self_multihead_attn',
                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',