Commit 9615983e authored by Masaki Kozuki, committed by hubertlu-tw

Remove `THCState` from `apex/contrib/multihead_attn` (#1239)

* pass `self.mask_additive`

* clang-format

* removing THCState
parent d11ddccf
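
The change in a nutshell: each of these extensions previously declared `extern THCState *state;` (a symbol resolved by the PyTorch libraries) and now obtains its CUDA resources purely through the ATen C++ API. Below is a minimal sketch of the THC-free pattern that recurs throughout the diff; the helper name is illustrative, not part of the commit.

#include <ATen/cuda/CUDAContext.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>

// Illustrative helper (not from the commit): fetch the cuBLAS handle and the
// current CUDA stream from ATen instead of going through THCState, and make
// cuBLAS launch its kernels on PyTorch's current stream.
void bind_cublas_to_current_stream() {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);
}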
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {

std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const half *pad_mask, float dropout_prob);

torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                    \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                         \
  CHECK_CUDA(x);                                                               \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
                               torch::Tensor const &input,
                               torch::Tensor const &pad_mask,
                               float dropout_prob) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
               "Only BYTE is supported");
  }
  return fwd_cuda(is_training, heads, input,
                  use_mask ? static_cast<const half *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
                  torch::Tensor const &softmax_results,
                  torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  // AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
  //            "Only BYTE is supported");
  return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
                  dropout_prob);
}

} // namespace additive_mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd,
        "Self Multihead Attention masked softmax dropout -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd,
        "Self Multihead Attention masked softmax dropout -- Backward.");
}
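
For orientation only (this is not part of the commit): a sketch of how the `fwd` wrapper bound above could be driven from C++ once the extension is built. The shapes follow the asserts and the `fwd_cuda` implementation further down — attention scores of shape [heads * sequences, q_seq_len, k_seq_len] in half precision and a 2D additive mask; `example_call` and the exact mask shape are assumptions.

#include <torch/torch.h>

// Hypothetical driver, assuming the fwd declaration above is visible to it.
void example_call() {
  const int heads = 8, sequences = 4, seq_len = 16;
  auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
  // Scaled attention scores: [heads * sequences, q_seq_len, k_seq_len].
  torch::Tensor scores =
      torch::randn({heads * sequences, seq_len, seq_len}, opts);
  // Additive mask: 0 keeps a position, a large negative value masks it out.
  torch::Tensor additive_mask = torch::zeros({sequences, seq_len}, opts);
  auto outs = multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd(
      /*use_mask=*/true, /*is_training=*/true, heads, scores, additive_mask,
      /*dropout_prob=*/0.1f);
  // outs = {dropout_results, dropout_mask, softmax_results}
}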
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "softmax.h"

// symbol to be automatically resolved by PyTorch libs

namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {

std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const half *pad_mask, float dropout_prob) {
  const int attn_batches = input.size(0);
  const int sequences = attn_batches / heads;
  const int q_seq_len = input.size(1);
  const int k_seq_len = q_seq_len;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = input.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *input_ptr = static_cast<void *>(input.data_ptr());
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
        attn_batches * q_seq_len);
  } else {
    softmax_success = dispatch_additive_masked_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
        k_seq_len, attn_batches * q_seq_len,
        attn_batches * q_seq_len / sequences);
  }

  if (is_training) {
    // use at:: function so that C++ version generates the same random mask as
    // python version
    auto dropout_tuple =
        at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
    dropout_results = std::get<0>(dropout_tuple);
    dropout_mask = std::get<1>(dropout_tuple);
  }

  // Matmul2

  return {dropout_results, dropout_mask, softmax_results};
}

torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask, float dropout_prob) {
  const int attn_batches = output_grads.size(0);
  const int q_seq_len = output_grads.size(1);
  const int k_seq_len = q_seq_len;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  // torch::Tensor input_grads = torch::empty_like(output_grads);

  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
  dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
      static_cast<half *>(output_grads.data_ptr()),
      static_cast<half *>(output_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()),
      static_cast<uint8_t const *>(dropout_mask.data_ptr()),
      1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
      attn_batches * q_seq_len, stream);
  // backward pass is completely in-place
  return output_grads;
}

} // namespace additive_mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
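
A note on the backward path above: `dispatch_masked_scale_softmax_backward_stream` fuses the dropout-mask/scale step with the softmax gradient. The host-side reference below is my reading of that computation for a single row of length k_seq_len (a sketch under those assumptions, not code from the commit): the incoming gradient is rescaled by the saved dropout mask and 1/(1 - dropout_prob), then the standard softmax backward is applied using the saved softmax outputs.

#include <cstdint>
#include <vector>

// Reference for one attention row: grad_in_i = y_i * (g_i - sum_j g_j * y_j),
// where y is the saved softmax output and g is the dropout-rescaled gradient.
std::vector<float> masked_scale_softmax_backward_row(
    const std::vector<float> &grad_out, const std::vector<float> &softmax_out,
    const std::vector<uint8_t> &dropout_mask, float dropout_prob) {
  const float scale = 1.0f / (1.0f - dropout_prob);
  std::vector<float> scaled(grad_out.size());
  float dot = 0.0f;
  for (size_t i = 0; i < grad_out.size(); ++i) {
    scaled[i] = grad_out[i] * dropout_mask[i] * scale; // undo dropout scaling
    dot += scaled[i] * softmax_out[i];
  }
  std::vector<float> grad_in(grad_out.size());
  for (size_t i = 0; i < grad_out.size(); ++i)
    grad_in[i] = softmax_out[i] * (scaled[i] - dot); // softmax backward
  return grad_in;
}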
@@ -5,145 +5,121 @@ namespace multihead_attn {
namespace encdec {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs_q,
                                    torch::Tensor const &inputs_kv,
                                    torch::Tensor const &input_weights_q,
                                    torch::Tensor const &input_weights_kv,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_q_results,
    torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
    torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
    torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                    \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                         \
  CHECK_CUDA(x);                                                               \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
    torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
    torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }
  return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
                  input_weights_q, input_weights_kv, output_weights,
                  use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_q_results,
    torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
    torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
    torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_q_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_kv_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_q_results, input_lin_kv_results,
                  inputs_q, inputs_kv, input_weights_q, input_weights_kv,
                  output_weights, dropout_mask, dropout_prob);
}

} // end namespace rocblas_gemm_ex
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "layer_norm.h"
#include "softmax.h"
#include "strided_batched_gemm.h"

namespace multihead_attn {
namespace encdec {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs_q,
                                    torch::Tensor const &inputs_kv,
                                    torch::Tensor const &input_weights_q,
                                    torch::Tensor const &input_weights_kv,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob) {
  const int embed_dim = inputs_q.size(2);
  const int sequences = inputs_q.size(1);
  const int q_seq_len = inputs_q.size(0);
  const int k_seq_len = inputs_kv.size(0);
  const int batches_q = sequences * q_seq_len;
  const int batches_kv = sequences * k_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_q_dim = embed_dim;
  const int output_lin_kv_dim = 2 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim_q = attn_batches * head_dim;
  const int lead_dim_kv = attn_batches * 2 * head_dim;
  const int batch_stride_q = head_dim;
  const int batch_stride_kv = 2 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = inputs_q.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);
  torch::Tensor input_lin_q_results =
      torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options);
  torch::Tensor input_lin_kv_results =
      torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor outputs = torch::empty_like(inputs_q, act_options);

  // Input Linear Results Pointers to Q, K, and V of interviewed activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_q_results.data_ptr());
  void *k_lin_results_ptr =
      static_cast<void *>(input_lin_kv_results.data_ptr());
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
  char b_layout_n{'n'};
@@ -166,43 +166,33 @@ std::vector<torch::Tensor> fwd_cuda(
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
        k_seq_len, attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len,
          attn_batches * q_seq_len / sequences);
    }
  }
  assert(softmax_success);

  if (is_training) {
    apex_fused_dropout_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(softmax_results.data_ptr()),
        static_cast<at::Half *>(dropout_results.data_ptr()),
        static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
        (1.0f - dropout_prob));
  }

  // Matmul2
  gemm_switch_fp32accum( state,
                         a_layout_n,
@@ -253,78 +243,73 @@ std::vector<torch::Tensor> fwd_cuda(
                         flags));
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_lin_q_results,
          input_lin_kv_results,
          softmax_results,
          dropout_results,
          dropout_mask,
          matmul2_results,
          outputs};
}

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_q_results,
    torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
    torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
    torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  const int embed_dim = inputs_q.size(2);
  const int sequences = inputs_q.size(1);
  const int q_seq_len = inputs_q.size(0);
  const int k_seq_len = inputs_kv.size(0);
  const int batches_q = sequences * q_seq_len;
  const int batches_kv = sequences * k_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_q_dim = embed_dim;
  const int output_lin_kv_dim = 2 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim_q = attn_batches * head_dim;
  const int lead_dim_kv = attn_batches * 2 * head_dim;
  const int batch_stride_q = head_dim;
  const int batch_stride_kv = 2 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_q_grads = torch::empty_like(inputs_q);
  torch::Tensor input_kv_grads = torch::empty_like(inputs_kv);
  torch::Tensor input_weight_q_grads = torch::empty_like(input_weights_q);
  torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);
  // Intermediate Tensor Allocations
  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads = torch::empty_like(dropout_results);
  at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results);
  at::Tensor input_lin_kv_output_grads =
      torch::empty_like(input_lin_kv_results);

  auto q_lin_results_ptr = static_cast<half *>(input_lin_q_results.data_ptr());
  auto k_lin_results_ptr = static_cast<half *>(input_lin_kv_results.data_ptr());
  auto v_lin_results_ptr =
      static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim;

  auto q_lin_grads_ptr =
      static_cast<half *>(input_lin_q_output_grads.data_ptr());
  auto k_lin_grads_ptr =
      static_cast<half *>(input_lin_kv_output_grads.data_ptr());
  auto v_lin_grads_ptr =
      static_cast<half *>(input_lin_kv_output_grads.data_ptr()) + head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
@@ -442,12 +427,10 @@ std::vector<torch::Tensor> bwd_cuda(
  // Softmax Grad
  bool softmax_success = false;
  softmax_success = dispatch_softmax_backward<half, half, float>(
      static_cast<half *>(matmul2_grads.data_ptr()),
      static_cast<half *>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
      k_seq_len, attn_batches * q_seq_len);
  assert(softmax_success);

  // Matmul1 Dgrad1
@@ -5,81 +5,66 @@ namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {

std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const uint8_t *pad_mask,
                                    float dropout_prob);

torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask,
                       const uint8_t *padding_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                    \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                         \
  CHECK_CUDA(x);                                                               \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
                               torch::Tensor const &input,
                               torch::Tensor const &pad_mask,
                               float dropout_prob) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }
  return fwd_cuda(is_training, heads, input,
                  use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
                  torch::Tensor const &softmax_results,
                  torch::Tensor const &dropout_mask,
                  torch::Tensor const &padding_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  // AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
  //            "Only BYTE is supported");
  return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
                  use_mask
                      ? static_cast<const uint8_t *>(padding_mask.data_ptr())
                      : nullptr,
                  dropout_prob);
}

} // end namespace mask_softmax_dropout
@@ -87,7 +72,8 @@ torch::Tensor bwd(
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd,
        "Self Multihead Attention masked softmax dropout -- Forward.");
  m.def("backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd,
        "Self Multihead Attention masked softmax dropout -- Backward.");
}
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "softmax.h"

namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {

std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const uint8_t *pad_mask,
                                    float dropout_prob) {
  const int attn_batches = input.size(0);
  const int sequences = attn_batches / heads;
  const int q_seq_len = input.size(1);
  const int k_seq_len = q_seq_len;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = input.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *input_ptr = static_cast<void *>(input.data_ptr());
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
        attn_batches * q_seq_len);
  } else {
    softmax_success = dispatch_masked_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
        k_seq_len, attn_batches * q_seq_len,
        attn_batches * q_seq_len / sequences);
  }

  if (is_training) {
    // use at:: function so that C++ version generates the same random mask as
    // python version
    auto dropout_tuple =
        at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
    dropout_results = std::get<0>(dropout_tuple);
    dropout_mask = std::get<1>(dropout_tuple);
  }

  // Matmul2

  return {dropout_results, dropout_mask, softmax_results};
}

torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask,
                       const uint8_t *padding_mask, float dropout_prob) {
  const int attn_batches = output_grads.size(0);
  const int q_seq_len = output_grads.size(1);
  const int k_seq_len = q_seq_len;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  // torch::Tensor input_grads = torch::empty_like(output_grads);

  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
  if (padding_mask == nullptr) {
    dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
        static_cast<half *>(output_grads.data_ptr()),
        static_cast<half *>(output_grads.data_ptr()),
        reinterpret_cast<half const *>(softmax_results.data_ptr()),
        static_cast<uint8_t const *>(dropout_mask.data_ptr()),
        1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
        attn_batches * q_seq_len, stream);
  } else {
    dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float,
                                                             false>(
        static_cast<half *>(output_grads.data_ptr()),
        static_cast<half *>(output_grads.data_ptr()),
        reinterpret_cast<half const *>(softmax_results.data_ptr()),
        static_cast<uint8_t const *>(dropout_mask.data_ptr()),
        static_cast<uint8_t const *>(padding_mask), 1.0 / (1.0 - dropout_prob),
        k_seq_len, k_seq_len, attn_batches * q_seq_len, heads, stream);
  }
  // backward pass is completely in-place
  return output_grads;
}

} // namespace mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
#pragma once

// Philox CUDA.

class Philox {
public:
@@ -15,28 +15,30 @@ public:
    incr_n(offset / 4);
  }

  __device__ inline uint4 operator()() {
    if (STATE == 0) {
      uint4 counter_ = counter;
      uint2 key_ = key;
      // 7-round philox
      for (int i = 0; i < 6; i++) {
        counter_ = single_round(counter_, key_);
        key_.x += (kPhilox10A);
        key_.y += (kPhilox10B);
      }
      output = single_round(counter_, key_);
      incr();
    }
    // return a float4 directly
    // unsigned long ret;
    // switch(STATE) {
    // case 0: ret = output.x; break;
    // case 1: ret = output.y; break;
    // case 2: ret = output.z; break;
    // case 3: ret = output.w; break;
    //}
    // STATE = (STATE + 1) % 4;
    return output;
  }

private:
  uint4 counter;
  uint4 output;
@@ -67,7 +69,7 @@ private:
  __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
                                    unsigned int *result_high) {
    *result_high = __umulhi(a, b);
    return a * b;
  }
  __device__ inline uint4 single_round(uint4 ctr, uint2 key) {
    unsigned int hi0;
@@ -84,7 +86,7 @@ private:
};
// Inverse of 2^32.
#define M_RAN_INVM32 2.3283064e-10f
__device__ __inline__ float4 uniform4(uint4 x) {
  return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,
                     x.w * M_RAN_INVM32);
}
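
How the Philox output above typically feeds the dropout masks used elsewhere in this diff, as a sketch (the helper below is illustrative, not from the commit): each 32-bit lane is mapped to a uniform float in [0, 1) via M_RAN_INVM32 = 2^-32, exactly as `uniform4` does, and compared against the keep probability 1 - dropout_prob.

#include <cstdint>

// Illustrative only: one 32-bit Philox lane -> uniform float -> keep decision.
constexpr float kRanInvM32 = 2.3283064e-10f; // 2^-32, same constant as above
inline bool keep_element(uint32_t raw_lane, float keep_prob) {
  float u = raw_lane * kRanInvM32; // uniform in [0, 1)
  return u < keep_prob;            // element survives dropout
}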
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace self_bias_additive_mask {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    torch::Tensor const &input_biases,
                                    torch::Tensor const &output_biases,
                                    const half *pad_mask, float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    // torch::Tensor const& softmax_results,
    torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    // torch::Tensor const& input_biases,
    // torch::Tensor const& output_biases,
    torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface
#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                    \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                         \
  CHECK_CUDA(x);                                                               \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &input_weights,
    torch::Tensor const &output_weights, torch::Tensor const &input_biases,
    torch::Tensor const &output_biases, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(use_mask, "no mask is not supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
               "Only Half is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights,
                  output_weights, input_biases, output_biases,
                  use_mask ? static_cast<const half *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}
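As a usage note, the assertions above pin down what a caller has to pass in: fp16 CUDA tensors, activations laid out as [seq_len, batch, embed_dim] (inferred from how fwd_cuda reads inputs.size(0/1/2)), and a 2-D half-precision additive pad mask. The snippet below is a minimal, hypothetical construction of such arguments with libtorch; the weight and bias shapes are assumptions based on the fused 3*embed_dim QKV projection, not something this file checks.

#include <torch/torch.h>

// Illustration only: build arguments that satisfy the dim/dtype assertions.
void build_example_args() {
  const int64_t seq_len = 64, batch = 8, embed_dim = 1024;
  auto half_cuda = torch::dtype(torch::kHalf).device(torch::kCUDA);
  torch::Tensor inputs = torch::randn({seq_len, batch, embed_dim}, half_cuda);
  // Assumed shapes: fused QKV projection (3*embed_dim) and output projection.
  torch::Tensor input_weights = torch::randn({3 * embed_dim, embed_dim}, half_cuda);
  torch::Tensor output_weights = torch::randn({embed_dim, embed_dim}, half_cuda);
  torch::Tensor input_biases = torch::zeros({3 * embed_dim}, half_cuda);
  torch::Tensor output_biases = torch::zeros({embed_dim}, half_cuda);
  // Additive mask: 0 keeps a key position, a large negative value suppresses it.
  torch::Tensor pad_mask = torch::zeros({batch, seq_len}, half_cuda);
}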
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  bmm1_results, pad_mask, input_lin_results, inputs,
                  input_weights, output_weights, dropout_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
@@ -140,4 +112,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
  m.def("backward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}

#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "dropout.h"
#include "layer_norm.h"
#include "softmax.h"
#include "strided_batched_gemm.h"

namespace multihead_attn {
namespace self_bias_additive_mask {
@@ -55,28 +52,36 @@ std::vector<torch::Tensor> fwd_cuda(
  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = inputs.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);
  torch::Tensor input_lin_results =
      torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor bmm1_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor outputs = torch::empty_like(inputs, act_options);
  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
  void *k_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *bmm1_results_ptr = static_cast<void *>(bmm1_results.data_ptr());
  void *dropout_results_ptr = static_cast<void *>(dropout_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
@@ -136,27 +141,24 @@ std::vector<torch::Tensor> fwd_cuda(
  // Padded Softmax
  bool softmax_success = false;
  if (is_training) {
    softmax_success =
        dispatch_additive_masked_softmax_dropout<half, half, float>(
            reinterpret_cast<half *>(dropout_results_ptr),
            (is_training)
                ? reinterpret_cast<uint8_t *>(dropout_mask.data_ptr<uint8_t>())
                : nullptr,
            reinterpret_cast<const half *>(bmm1_results_ptr), pad_mask,
            attn_batches * q_seq_len * q_seq_len, k_seq_len, k_seq_len,
            attn_batches * q_seq_len, attn_batches * q_seq_len / sequences,
            1.0f - dropout_prob, stream);
  } else {
    softmax_success = dispatch_additive_masked_softmax<half, half, float>(
        reinterpret_cast<half *>(
            dropout_results_ptr), // this is actually softmax results, but
                                  // making it consistent for the next function
        reinterpret_cast<const half *>(bmm1_results_ptr), pad_mask, k_seq_len,
        k_seq_len, attn_batches * q_seq_len,
        attn_batches * q_seq_len / sequences);
  }
  // Matmul2
@@ -211,73 +213,63 @@ std::vector<torch::Tensor> fwd_cuda(
      flags));
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_lin_results, bmm1_results,    dropout_results,
          dropout_mask,      matmul2_results, outputs};
}
std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_grads = torch::empty_like(inputs);
  torch::Tensor input_weight_grads = torch::empty_like(input_weights);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);
  // Intermediate Tensor Allocations
  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads = torch::empty_like(dropout_results);
  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);

  auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
  auto k_lin_results_ptr =
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr =
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr =
      static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr =
      static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};

  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
@@ -496,13 +488,8 @@ std::vector<torch::Tensor> bwd_cuda(
  auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_grads, input_weight_grads, output_weight_grads,
          input_bias_grads, output_bias_grads};
}
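One detail worth calling out in the backward pass above: the bias gradients are not produced by a separate GEMM; input_bias_grads is just the linear-output gradient reshaped to [tokens, 3*embed_dim] and summed over the token axis (output_bias_grads, computed in an elided hunk, follows the same pattern). A small sketch of that reduction, with the layout taken as an assumption:

#include <torch/torch.h>

// Illustration only: for y = x @ W^T + b, dL/db is the sum of dL/dy over all
// rows, which is exactly the view(...).sum(0) reduction used above.
torch::Tensor bias_grad_sketch(const torch::Tensor &lin_output_grads,
                               int64_t output_lin_dim) {
  return lin_output_grads.view({-1, output_lin_dim}).sum(0, /*keepdim=*/false);
}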
} // end namespace rocblas_gemmex
@@ -5,127 +5,102 @@ namespace multihead_attn {
namespace self_bias {
namespace rocblas_gemmex {

std::vector<torch::Tensor>
fwd_cuda(bool use_time_mask, bool is_training, int heads,
         torch::Tensor const &inputs, torch::Tensor const &input_weights,
         torch::Tensor const &output_weights, torch::Tensor const &input_biases,
         torch::Tensor const &output_biases, const uint8_t *pad_mask,
         float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    // torch::Tensor const& input_biases,
    // torch::Tensor const& output_biases,
    torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface
#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                    \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                         \
  CHECK_CUDA(x);                                                               \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &input_weights,
    torch::Tensor const &output_weights, torch::Tensor const &input_biases,
    torch::Tensor const &output_biases, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights,
                  output_weights, input_biases, output_biases,
                  use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_results, inputs, input_weights,
                  output_weights, dropout_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
@@ -136,4 +111,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::self_bias::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
  m.def("backward", &multihead_attn::self_bias::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}

#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "dropout.h"
#include "layer_norm.h"
#include "softmax.h"
#include "strided_batched_gemm.h"

namespace multihead_attn {
namespace self_bias {
namespace rocblas_gemmex {

std::vector<torch::Tensor>
fwd_cuda(bool use_time_mask, bool is_training, int heads,
         torch::Tensor const &inputs, torch::Tensor const &input_weights,
         torch::Tensor const &output_weights, torch::Tensor const &input_biases,
         torch::Tensor const &output_biases, const uint8_t *pad_mask,
         float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta_zero = 0.0;
  const float beta_one = 1.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = inputs.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);
  torch::Tensor input_lin_results =
      torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor outputs = torch::empty_like(inputs, act_options);
  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
  void *k_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
@@ -136,37 +134,29 @@ std::vector<torch::Tensor> fwd_cuda(
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
        k_seq_len, attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len,
          attn_batches * q_seq_len / sequences);
    }
  }
  if (is_training) {
    // use at:: function so that C++ version generates the same random mask as
    // python version
    auto dropout_tuple =
        at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
    dropout_results = std::get<0>(dropout_tuple);
    dropout_mask = std::get<1>(dropout_tuple);
  }
@@ -223,72 +213,63 @@ std::vector<torch::Tensor> fwd_cuda(
      flags));
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_lin_results, softmax_results, dropout_results,
          dropout_mask,      matmul2_results, outputs};
}
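The training branch above leans on at::_fused_dropout(x, p), where p is the keep probability (hence the 1.0f - dropout_prob), returning the scaled output together with a byte mask so the C++ path reproduces the same masking scheme as the Python implementation. The sketch below only restates the math that call implements; the real kernel naturally draws from its own RNG stream, so the values differ.

#include <ATen/ATen.h>
#include <tuple>

// Illustration only: dropout with keep probability p zeroes elements where the
// Bernoulli(p) mask is 0 and scales the kept ones by 1/p.
std::tuple<at::Tensor, at::Tensor> fused_dropout_sketch(const at::Tensor &x,
                                                        double keep_prob) {
  at::Tensor mask =
      at::empty_like(x, x.options().dtype(at::kByte)).bernoulli_(keep_prob);
  at::Tensor out = x * mask.to(x.scalar_type()) * (1.0 / keep_prob);
  return std::make_tuple(out, mask);
}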
std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_grads = torch::empty_like(inputs);
  torch::Tensor input_weight_grads = torch::empty_like(input_weights);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);
  // Intermediate Tensor Allocations
  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads = torch::empty_like(dropout_results);
  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);

  auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
  auto k_lin_results_ptr =
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr =
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr =
      static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr =
      static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};

  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
@@ -393,15 +374,13 @@ std::vector<torch::Tensor> bwd_cuda(
  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
  dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
      static_cast<half *>(matmul2_grads.data_ptr()),
      static_cast<half *>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()),
      static_cast<uint8_t const *>(dropout_mask.data_ptr()),
      1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
      attn_batches * q_seq_len, stream);

  // Matmul1 Dgrad1
  gemm_switch_fp32accum( state,
@@ -503,13 +482,8 @@ std::vector<torch::Tensor> bwd_cuda(
  auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_grads, input_weight_grads, output_weight_grads,
          input_bias_grads, output_bias_grads};
}
} // end namespace rocblas_gemmex
@@ -5,120 +5,98 @@ namespace multihead_attn {
namespace self {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface
#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                    \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                         \
  CHECK_CUDA(x);                                                               \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &input_weights,
    torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(
      use_time_mask, is_training, heads, inputs, input_weights, output_weights,
      use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr()) : nullptr,
      dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_results, inputs, input_weights,
                  output_weights, dropout_mask, dropout_prob);
}
} // end namespace rocblas_gemm_ex
@@ -129,4 +107,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::self::rocblas_gemmex::fwd, "Self Multihead Attention Forward.");
  m.def("backward", &multihead_attn::self::rocblas_gemmex::bwd, "Self Multihead Attention Backward.");
}