Unverified Commit 1203099a authored by Masaki Kozuki, committed by GitHub

Remove `THCState` from `apex/contrib/multihead_attn` (#1239)

* pass `self.mask_additive`

* clang-format

* removing THCState
parent 3c8f5161
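The change below follows one pattern throughout: nothing is resolved through the legacy THC state anymore; each CUDA source obtains the current cuBLAS handle and CUDA stream directly from ATen and binds them before launching work. A minimal sketch of that pattern (the helper name is illustrative, not part of this diff):

#include <ATen/cuda/CUDAContext.h>
#include <cublas_v2.h>

// Illustrative helper only: bind the stream ATen is currently using to the
// current cuBLAS handle, mirroring what the fwd_cuda/bwd_cuda functions below do inline.
static void bind_current_stream_to_cublas() {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);
}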
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
torch::Tensor const &input,
const half *pad_mask, float dropout_prob);
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
torch::Tensor const &input,
torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
"Only BYTE is supported");
}
return fwd_cuda(is_training, heads, input,
use_mask ? static_cast<const half *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
// "Only BYTE is supported");
return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
dropout_prob);
}
} // namespace additive_mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd, "Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd, "Self Multihead Attention masked softmax dropout -- Backward.");
m.def("forward",
&multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd,
"Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd,
"Self Multihead Attention masked softmax dropout -- Backward.");
}
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "softmax.h"
#include "dropout.h"
#include "softmax.h"
namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
torch::Tensor const &input,
const half *pad_mask, float dropout_prob) {
const int attn_batches = input.size(0);
const int sequences = attn_batches / heads;
const int q_seq_len = input.size(1);
const int k_seq_len = q_seq_len;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = input.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void *input_ptr = static_cast<void *>(input.data_ptr());
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
attn_batches * q_seq_len);
} else {
softmax_success = dispatch_additive_masked_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
if (is_training) {
// use at:: function so that C++ version generates the same random mask as
// python version
auto dropout_tuple =
at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
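// at::_fused_dropout takes the keep probability (1 - p) and returns a
// (scaled output, byte mask) pair.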
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
// Matmul2
return {dropout_results, dropout_mask, softmax_results};
}
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask, float dropout_prob) {
const int attn_batches = output_grads.size(0);
const int q_seq_len = output_grads.size(1);
const int k_seq_len = q_seq_len;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Output Tensor Allocations
// torch::Tensor input_grads = torch::empty_like(output_grads);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
static_cast<half *>(output_grads.data_ptr()),
static_cast<half *>(output_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()),
static_cast<uint8_t const *>(dropout_mask.data_ptr()),
1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
attn_batches * q_seq_len, stream);
// backward pass is completely in-place
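// (the gradient w.r.t. the softmax input is written into output_grads and
// returned directly; no extra tensor is allocated)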
return output_grads;
}
} // namespace additive_mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
@@ -5,152 +5,130 @@ namespace multihead_attn {
namespace encdec {
namespace cublas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv,
torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
input_weights_q, input_weights_kv, output_weights,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_q_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_kv_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_q_results, input_lin_kv_results,
inputs_q, inputs_kv, input_weights_q, input_weights_kv,
output_weights, dropout_mask, dropout_prob);
}
} // end namespace cublas_gemmex
} // end namespace encdec
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::encdec::cublas_gemmex::fwd, "Encdec Multihead Attention Forward.");
m.def("backward", &multihead_attn::encdec::cublas_gemmex::bwd, "Encdec Multihead Attention Backward.");
m.def("forward", &multihead_attn::encdec::cublas_gemmex::fwd,
"Encdec Multihead Attention Forward.");
m.def("backward", &multihead_attn::encdec::cublas_gemmex::bwd,
"Encdec Multihead Attention Backward.");
}
@@ -5,81 +5,66 @@ namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
torch::Tensor const &input,
const uint8_t *pad_mask,
float dropout_prob);
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask,
const uint8_t *padding_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
torch::Tensor const &input,
torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(is_training, heads, input,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask,
torch::Tensor const &padding_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
// "Only BYTE is supported");
return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
use_mask
? static_cast<const uint8_t *>(padding_mask.data_ptr())
: nullptr,
dropout_prob);
}
} // end namespace mask_softmax_dropout
@@ -87,7 +72,8 @@ torch::Tensor bwd(
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd, "Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd, "Self Multihead Attention masked softmax dropout -- Backward.");
m.def("forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd,
"Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd,
"Self Multihead Attention masked softmax dropout -- Backward.");
}
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "softmax.h"
#include "dropout.h"
#include "softmax.h"
namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
torch::Tensor const &input,
const uint8_t *pad_mask,
float dropout_prob) {
const int attn_batches = input.size(0);
const int sequences = attn_batches / heads;
const int q_seq_len = input.size(1);
const int k_seq_len = q_seq_len;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = input.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void *input_ptr = static_cast<void *>(input.data_ptr());
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
attn_batches * q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
if (is_training) {
// use at:: function so that C++ version generates the same random mask as
// python version
auto dropout_tuple =
at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
// Matmul2
return {dropout_results, dropout_mask, softmax_results};
}
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask,
const uint8_t *padding_mask, float dropout_prob) {
const int attn_batches = output_grads.size(0);
const int q_seq_len = output_grads.size(1);
const int k_seq_len = q_seq_len;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Output Tensor Allocations
// torch::Tensor input_grads = torch::empty_like(output_grads);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
if (padding_mask == nullptr) {
dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
static_cast<half *>(output_grads.data_ptr()),
static_cast<half *>(output_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()),
static_cast<uint8_t const *>(dropout_mask.data_ptr()),
1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
attn_batches * q_seq_len, stream);
} else {
dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float,
false>(
static_cast<half *>(output_grads.data_ptr()),
static_cast<half *>(output_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()),
static_cast<uint8_t const *>(dropout_mask.data_ptr()),
static_cast<uint8_t const *>(padding_mask), 1.0 / (1.0 - dropout_prob),
k_seq_len, k_seq_len, attn_batches * q_seq_len, heads, stream);
}
// backward pass is completely in-place
return output_grads;
}
} // namespace mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
#pragma once
// Philox CUDA.
class Philox {
public:
@@ -15,28 +15,30 @@ public:
incr_n(offset / 4);
}
__device__ inline uint4 operator()() {
if (STATE == 0) {
uint4 counter_ = counter;
uint2 key_ = key;
// 7-round philox
for (int i = 0; i < 6; i++) {
counter_ = single_round(counter_, key_);
key_.x += (kPhilox10A);
key_.y += (kPhilox10B);
}
output = single_round(counter_, key_);
incr();
}
// return a float4 directly
// unsigned long ret;
// switch(STATE) {
// case 0: ret = output.x; break;
// case 1: ret = output.y; break;
// case 2: ret = output.z; break;
// case 3: ret = output.w; break;
//}
// STATE = (STATE + 1) % 4;
return output;
}
private:
uint4 counter;
uint4 output;
@@ -67,7 +69,7 @@ private:
__device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
unsigned int *result_high) {
*result_high = __umulhi(a, b);
return a * b;
}
__device__ inline uint4 single_round(uint4 ctr, uint2 key) {
unsigned int hi0;
@@ -84,7 +86,7 @@ private:
};
// Inverse of 2^32.
#define M_RAN_INVM32 2.3283064e-10f
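// Multiplying a 32-bit Philox draw by this constant maps it into [0, 1),
// which is how uniform4 below turns raw counter output into uniform floats.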
__device__ __inline__ float4 uniform4(uint4 x) {
return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,
x.w * M_RAN_INVM32);
}
@@ -5,120 +5,98 @@ namespace multihead_attn {
namespace self {
namespace cublas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs,
torch::Tensor const &input_weights,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs, torch::Tensor const &input_weights,
torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(
use_time_mask, is_training, heads, inputs, input_weights, output_weights,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr()) : nullptr,
dropout_prob);
}
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_results, inputs, input_weights,
output_weights, dropout_mask, dropout_prob);
}
} // end namespace cublas_gemmex
@@ -126,7 +104,8 @@ std::vector<torch::Tensor> bwd(
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self::cublas_gemmex::fwd, "Self Multihead Attention Forward.");
m.def("backward", &multihead_attn::self::cublas_gemmex::bwd, "Self Multihead Attention Backward.");
m.def("forward", &multihead_attn::self::cublas_gemmex::fwd,
"Self Multihead Attention Forward.");
m.def("backward", &multihead_attn::self::cublas_gemmex::bwd,
"Self Multihead Attention Backward.");
}