Unverified Commit 1203099a authored by Masaki Kozuki, committed by GitHub

Remove `THCState` from `apex/contrib/multihead_attn` (#1239)

* pass `self.mask_additive`

* clang-format

* removing THCState
parent 3c8f5161
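
The `THCState` removal itself is mechanical: these extensions still carried a leftover `extern THCState *state;` declaration from the old THC API, while the CUDA stream is already obtained through ATen. A minimal sketch of the before/after pattern, assuming a caller that owns a cuBLAS handle; the `THCState_getCurrentStream` line and the helper name are illustrative of the legacy usage, not taken from this diff:

#include <ATen/cuda/CUDAContext.h>
#include <cublas_v2.h>

// Legacy pattern this commit deletes (shown only as a comment):
//   extern THCState *state;  // symbol to be automatically resolved by PyTorch libs
//   cudaStream_t stream = THCState_getCurrentStream(state);

// ATen-only pattern the files rely on after the commit:
void set_stream_on_handle(cublasHandle_t handle) { // hypothetical helper
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);
}
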
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {

std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const half *pad_mask, float dropout_prob);

torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x) \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
                               torch::Tensor const &input,
                               torch::Tensor const &pad_mask,
                               float dropout_prob) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
               "Only BYTE is supported");
  }

  return fwd_cuda(is_training, heads, input,
                  use_mask ? static_cast<const half *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
                  torch::Tensor const &softmax_results,
                  torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  // AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
  //            "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
                  dropout_prob);
}

} // namespace additive_mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd,
        "Self Multihead Attention masked softmax dropout -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd,
        "Self Multihead Attention masked softmax dropout -- Backward.");
}
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "softmax.h"

// symbol to be automatically resolved by PyTorch libs

namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {

std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const half *pad_mask, float dropout_prob) {
  const int attn_batches = input.size(0);
  const int sequences = attn_batches / heads;
  const int q_seq_len = input.size(1);
@@ -41,63 +35,54 @@ std::vector<torch::Tensor> fwd_cuda(
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = input.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *input_ptr = static_cast<void *>(input.data_ptr());
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
        attn_batches * q_seq_len);
  } else {
    softmax_success = dispatch_additive_masked_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
        k_seq_len, attn_batches * q_seq_len,
        attn_batches * q_seq_len / sequences);
  }

  if (is_training) {
    // use at:: function so that C++ version generates the same random mask as
    // python version
    auto dropout_tuple =
        at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
    dropout_results = std::get<0>(dropout_tuple);
    dropout_mask = std::get<1>(dropout_tuple);
  }

  // Matmul2

  return {dropout_results, dropout_mask, softmax_results};
}

torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask, float dropout_prob) {
  const int attn_batches = output_grads.size(0);
  const int q_seq_len = output_grads.size(1);
  const int k_seq_len = q_seq_len;
@@ -109,23 +94,20 @@ torch::Tensor bwd_cuda(
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  // torch::Tensor input_grads = torch::empty_like(output_grads);

  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
  dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
      static_cast<half *>(output_grads.data_ptr()),
      static_cast<half *>(output_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()),
      static_cast<uint8_t const *>(dropout_mask.data_ptr()),
      1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
      attn_batches * q_seq_len, stream);

  // backward pass is completely in-place
  return output_grads;
}

} // namespace additive_mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
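
Reading the backward wrapper above together with its comments, the dropout backward is folded into the softmax backward in a single pass over `output_grads`: with incoming gradients g, softmax outputs y, dropout mask m, and keep probability 1 - dropout_prob, the fused kernel presumably first rescales each element to g'_i = g_i * m_i / (1 - dropout_prob) and then applies the softmax Jacobian

    dx_i = y_i * (g'_i - sum_j g'_j * y_j),

writing dx back over `output_grads`, which is why the backward pass is described as completely in-place. This reading is inferred from the call site here, not from the kernel in `softmax.h`.
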
@@ -11,33 +11,22 @@
const int UNROLL = 4;

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void
apex_fused_dropout_kernel(scalar_t const *inputs, scalar_t *outputs,
                          uint8_t *mask, IndexType totalElements, accscalar_t p,
                          std::pair<uint64_t, uint64_t> seeds) {
  accscalar_t pinv = accscalar_t(1) / p;
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seeds.first, idx, seeds.second, &state);

  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
      blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    float4 rand = curand_uniform4(&state);
    scalar_t src[UNROLL];
    rand.x = rand.x <= p;
@@ -54,7 +43,7 @@ __global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        outputs[li] = src[ii] * (&rand.x)[ii] * pinv;
        mask[li] = (uint8_t)(&rand.x)[ii];
      }
    }
@@ -62,34 +51,23 @@ __global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_dropout_add_kernel(scalar_t const *inputs,
                                        scalar_t const *add_inputs,
                                        scalar_t *outputs, uint8_t *mask,
                                        IndexType totalElements, accscalar_t p,
                                        std::pair<uint64_t, uint64_t> seeds) {
  accscalar_t pinv = accscalar_t(1) / p;
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seeds.first, idx, seeds.second, &state);

  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
      blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    float4 rand = curand_uniform4(&state);
    scalar_t src[UNROLL];
    scalar_t add_src[UNROLL];
@@ -108,7 +86,8 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs,
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        accscalar_t int1 = src[ii] * (&rand.x)[ii] * pinv;
        outputs[li] =
            static_cast<scalar_t>(static_cast<accscalar_t>(add_src[ii]) + int1);
        mask[li] = (uint8_t)(&rand.x)[ii];
      }
    }
@@ -116,22 +95,16 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs,
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_add_kernel(scalar_t const *inputs,
                                scalar_t const *add_inputs, scalar_t *outputs,
                                IndexType totalElements) {
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
      blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    scalar_t src[UNROLL];
    scalar_t add_src[UNROLL];
    for (int ii = 0; ii < UNROLL; ii++) {
@@ -151,23 +124,17 @@ __global__ void apex_add_kernel(scalar_t const *inputs,
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_masked_scale_kernel(scalar_t const *inputs,
                                         scalar_t *outputs, uint8_t const *mask,
                                         IndexType totalElements,
                                         accscalar_t scale) {
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
      blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    scalar_t src[UNROLL];
    scalar_t msk[UNROLL];
    for (int ii = 0; ii < UNROLL; ii++) {
@@ -180,33 +147,34 @@ __global__ void apex_masked_scale_kernel(scalar_t const *inputs,
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        outputs[li] = static_cast<accscalar_t>(src[ii]) * scale *
                      static_cast<accscalar_t>(msk[ii]);
      }
    }
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_fused_dropout_cuda(scalar_t const *inputs, scalar_t *outputs,
                             uint8_t *mask, IndexType totalElements,
                             accscalar_t p) {
  auto gen = at::cuda::detail::getDefaultCUDAGenerator();

  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
      block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
                            ->multiProcessorCount *
                        blocks_per_sm,
                    grid.x);

  // number of times random will be generated per thread, to offset philox
  // counter in the random state
  int64_t counter_offset =
      ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
  std::pair<uint64_t, uint64_t> rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
@@ -215,36 +183,39 @@ void apex_fused_dropout_cuda(scalar_t const *inputs,
    rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
#else
    std::lock_guard<std::mutex> lock(gen.mutex());
    rng_engine_inputs =
        at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
            counter_offset);
#endif
  }

  apex_fused_dropout_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, outputs, mask, totalElements, p, rng_engine_inputs);
  C10_CUDA_CHECK(cudaGetLastError());
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_dropout_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs,
                           scalar_t *outputs, uint8_t *mask,
                           IndexType totalElements, accscalar_t p) {
  auto gen = at::cuda::detail::getDefaultCUDAGenerator();

  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
      block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
                            ->multiProcessorCount *
                        blocks_per_sm,
                    grid.x);

  // number of times random will be generated per thread, to offset philox
  // counter in the random state
  int64_t counter_offset =
      ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
  std::pair<uint64_t, uint64_t> rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
@@ -253,54 +224,56 @@ void apex_dropout_add_cuda(scalar_t const *inputs,
    rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
#else
    std::lock_guard<std::mutex> lock(gen.mutex());
    rng_engine_inputs =
        at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
            counter_offset);
#endif
  }

  apex_dropout_add_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, add_inputs, outputs, mask, totalElements, p,
          rng_engine_inputs);
  C10_CUDA_CHECK(cudaGetLastError());
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs,
                   scalar_t *outputs, IndexType totalElements) {
  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
      block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
                            ->multiProcessorCount *
                        blocks_per_sm,
                    grid.x);

  apex_add_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, add_inputs, outputs, totalElements);
  C10_CUDA_CHECK(cudaGetLastError());
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_masked_scale_cuda(scalar_t const *inputs, scalar_t *outputs,
                            uint8_t const *mask, IndexType totalElements,
                            accscalar_t scale) {
  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
      block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
                            ->multiProcessorCount *
                        blocks_per_sm,
                    grid.x);

  apex_masked_scale_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, outputs, mask, totalElements, scale);
  C10_CUDA_CHECK(cudaGetLastError());
}
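
Because these launchers are templates, a host-side caller picks the element, accumulation, and index types itself. Below is a minimal sketch of such a call, assuming compilation as CUDA and using illustrative names (`run_fused_dropout`, `in`, `out`, `mask`) and template arguments that are not part of this commit; note that the last argument is the keep probability, since the kernels scale retained elements by `1/p`:

#include <ATen/ATen.h>
#include <cstdint>

// Hypothetical wrapper around apex_fused_dropout_cuda for half-precision tensors.
void run_fused_dropout(at::Tensor const &in, at::Tensor &out, at::Tensor &mask,
                       float dropout_prob) {
  const float keep_prob = 1.0f - dropout_prob; // kernel expects keep probability
  apex_fused_dropout_cuda<at::Half, float, uint32_t>(
      in.data_ptr<at::Half>(), out.data_ptr<at::Half>(),
      mask.data_ptr<uint8_t>(), static_cast<uint32_t>(in.numel()), keep_prob);
}
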
@@ -5,103 +5,79 @@ namespace multihead_attn {
namespace encdec {
namespace cublas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs_q,
                                    torch::Tensor const &inputs_kv,
                                    torch::Tensor const &input_weights_q,
                                    torch::Tensor const &input_weights_kv,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_q_results,
    torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
    torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
    torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x) \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
    torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
    torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");

  AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
                  input_weights_q, input_weights_kv, output_weights,
                  use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_q_results,
    torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
    torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
    torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
@@ -115,35 +91,35 @@ std::vector<torch::Tensor> bwd(
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_q_results, input_lin_kv_results,
                  inputs_q, inputs_kv, input_weights_q, input_weights_kv,
                  output_weights, dropout_mask, dropout_prob);
}

} // end namespace cublas_gemmex
@@ -151,6 +127,8 @@ std::vector<torch::Tensor> bwd(
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::encdec::cublas_gemmex::fwd,
        "Encdec Multihead Attention Forward.");
  m.def("backward", &multihead_attn::encdec::cublas_gemmex::bwd,
        "Encdec Multihead Attention Backward.");
}
@@ -5,66 +5,49 @@ namespace multihead_attn {
namespace encdec_norm_add {
namespace cublas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs_q,
                                    torch::Tensor const &inputs_kv,
                                    torch::Tensor const &lyr_nrm_gamma_weights,
                                    torch::Tensor const &lyr_nrm_beta_weights,
                                    torch::Tensor const &input_weights_q,
                                    torch::Tensor const &input_weights_kv,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_q_results,
    torch::Tensor const &input_lin_kv_results,
    torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
    torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q,
    torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights,
    torch::Tensor const &lyr_nrm_beta_weights,
    torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
    torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
    torch::Tensor const &dropout_add_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x) \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
    torch::Tensor const &lyr_nrm_gamma_weights,
    torch::Tensor const &lyr_nrm_beta_weights,
    torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
    torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
@@ -73,58 +56,48 @@ std::vector<torch::Tensor> fwd(
  AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");

  AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
                  lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q,
                  input_weights_kv, output_weights,
                  use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_q_results,
    torch::Tensor const &input_lin_kv_results,
    torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
    torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q,
    torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights,
    torch::Tensor const &lyr_nrm_beta_weights,
    torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
    torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
    torch::Tensor const &dropout_add_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
@@ -144,47 +117,49 @@ std::vector<torch::Tensor> bwd(
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float,
             "Only FLOAT is supported");
  AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float,
             "Only FLOAT is supported");
  AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");
  AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_q_results, input_lin_kv_results,
                  lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs_q,
                  inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights,
                  input_weights_q, input_weights_kv, output_weights,
                  dropout_mask, dropout_add_mask, dropout_prob);
}

} // end namespace cublas_gemmex
@@ -192,7 +167,9 @@ std::vector<torch::Tensor> bwd(
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::encdec_norm_add::cublas_gemmex::fwd,
        "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.");
  m.def(
      "backward", &multihead_attn::encdec_norm_add::cublas_gemmex::bwd,
      "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.");
}
@@ -5,81 +5,66 @@ namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {

std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const uint8_t *pad_mask,
                                    float dropout_prob);

torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask,
                       const uint8_t *padding_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                    \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                         \
  CHECK_CUDA(x);                                                               \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
                               torch::Tensor const &input,
                               torch::Tensor const &pad_mask,
                               float dropout_prob) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(is_training, heads, input,
                  use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
                  torch::Tensor const &softmax_results,
                  torch::Tensor const &dropout_mask,
                  torch::Tensor const &padding_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  // AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
  //            "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
                  use_mask
                      ? static_cast<const uint8_t *>(padding_mask.data_ptr())
                      : nullptr,
                  dropout_prob);
}

} // end namespace mask_softmax_dropout

@@ -87,7 +72,8 @@ torch::Tensor bwd(
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd,
        "Self Multihead Attention masked softmax dropout -- Forward.");
  m.def("backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd,
        "Self Multihead Attention masked softmax dropout -- Backward.");
}
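Aside (not part of the diff): the binding above only validates dtypes and ranks, so the shape contract is easiest to see from a caller. A rough C++ smoke-test sketch follows; the sizes, and the assumption that nonzero bytes in pad_mask mark padded key positions, are illustrative only.

// Hypothetical smoke test, not part of the repository: exercises fwd()/bwd()
// above from C++. Requires libtorch and a CUDA device; shapes and mask
// semantics here are illustrative assumptions.
void mask_softmax_dropout_smoke_test() {
  const int heads = 8, batch = 4, seq_len = 64;
  auto half_opts = torch::dtype(torch::kHalf).device(torch::kCUDA);

  // Attention scores laid out as [heads * batch, q_len, k_len].
  torch::Tensor scores =
      torch::randn({heads * batch, seq_len, seq_len}, half_opts);
  // One byte row of key-padding flags per sequence: [batch, k_len].
  torch::Tensor pad_mask =
      torch::zeros({batch, seq_len}, half_opts.dtype(torch::kUInt8));

  auto fwd_out = fwd(/*use_mask=*/true, /*is_training=*/true, heads, scores,
                     pad_mask, /*dropout_prob=*/0.1f);
  torch::Tensor dropout_results = fwd_out[0];
  torch::Tensor dropout_mask = fwd_out[1];
  torch::Tensor softmax_results = fwd_out[2];

  // Backward consumes the saved softmax results and dropout mask; it writes
  // the gradient in place into output_grads and returns it.
  torch::Tensor output_grads = torch::ones_like(dropout_results);
  torch::Tensor input_grads = bwd(/*use_mask=*/true, heads, output_grads,
                                  softmax_results, dropout_mask, pad_mask,
                                  /*dropout_prob=*/0.1f);
}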
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "softmax.h"

-// symbol to be automatically resolved by PyTorch libs
-extern THCState *state;

namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {

std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const uint8_t *pad_mask,
                                    float dropout_prob) {
  const int attn_batches = input.size(0);
  const int sequences = attn_batches / heads;
  const int q_seq_len = input.size(1);

@@ -41,64 +34,55 @@ std::vector<torch::Tensor> fwd_cuda(
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = input.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *input_ptr = static_cast<void *>(input.data_ptr());
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
        attn_batches * q_seq_len);
  } else {
    softmax_success = dispatch_masked_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
        k_seq_len, attn_batches * q_seq_len,
        attn_batches * q_seq_len / sequences);
  }

  if (is_training) {
    // use at:: function so that C++ version generates the same random mask as
    // python version
    auto dropout_tuple =
        at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
    dropout_results = std::get<0>(dropout_tuple);
    dropout_mask = std::get<1>(dropout_tuple);
  }

  // Matmul2

  return {dropout_results, dropout_mask, softmax_results};
}

torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask,
                       const uint8_t *padding_mask, float dropout_prob) {
  const int attn_batches = output_grads.size(0);
  const int q_seq_len = output_grads.size(1);
  const int k_seq_len = q_seq_len;

@@ -110,38 +94,31 @@ torch::Tensor bwd_cuda(
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  // torch::Tensor input_grads = torch::empty_like(output_grads);

  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
  if (padding_mask == nullptr) {
    dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
        static_cast<half *>(output_grads.data_ptr()),
        static_cast<half *>(output_grads.data_ptr()),
        reinterpret_cast<half const *>(softmax_results.data_ptr()),
        static_cast<uint8_t const *>(dropout_mask.data_ptr()),
        1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
        attn_batches * q_seq_len, stream);
  } else {
    dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float,
                                                             false>(
        static_cast<half *>(output_grads.data_ptr()),
        static_cast<half *>(output_grads.data_ptr()),
        reinterpret_cast<half const *>(softmax_results.data_ptr()),
        static_cast<uint8_t const *>(dropout_mask.data_ptr()),
        static_cast<uint8_t const *>(padding_mask), 1.0 / (1.0 - dropout_prob),
        k_seq_len, k_seq_len, attn_batches * q_seq_len, heads, stream);
  }

  // backward pass is completely in-place
  return output_grads;
}

} // namespace mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
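Aside (not part of the diff): bwd_cuda() fuses the dropout backward into the softmax backward. The dropout piece alone is just a mask-and-rescale of the incoming gradient; a plain ATen reference of that piece, for intuition only and assuming 1-bytes in dropout_mask mark kept elements (as at::_fused_dropout produces):

// Reference-only sketch of the dropout portion that bwd_cuda() fuses into the
// softmax backward: scale the incoming gradient by the saved byte mask
// (1 = kept element) and by 1 / (1 - dropout_prob). Not the fused kernel.
torch::Tensor dropout_backward_reference(torch::Tensor const &output_grads,
                                         torch::Tensor const &dropout_mask,
                                         float dropout_prob) {
  return output_grads * dropout_mask.to(output_grads.scalar_type()) *
         (1.0 / (1.0 - dropout_prob));
}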
#pragma once
// Philox CUDA.

class Philox {
public:
@@ -15,28 +15,30 @@ public:
    incr_n(offset / 4);
  }

  __device__ inline uint4 operator()() {
    if (STATE == 0) {
      uint4 counter_ = counter;
      uint2 key_ = key;
      // 7-round philox
      for (int i = 0; i < 6; i++) {
        counter_ = single_round(counter_, key_);
        key_.x += (kPhilox10A);
        key_.y += (kPhilox10B);
      }
      output = single_round(counter_, key_);
      incr();
    }
    // return a float4 directly
    // unsigned long ret;
    // switch(STATE) {
    //  case 0: ret = output.x; break;
    //  case 1: ret = output.y; break;
    //  case 2: ret = output.z; break;
    //  case 3: ret = output.w; break;
    //}
    // STATE = (STATE + 1) % 4;
    return output;
  }

private:
  uint4 counter;
  uint4 output;
@@ -67,7 +69,7 @@ private:
  __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
                                    unsigned int *result_high) {
    *result_high = __umulhi(a, b);
    return a * b;
  }

  __device__ inline uint4 single_round(uint4 ctr, uint2 key) {
    unsigned int hi0;
@@ -85,6 +87,6 @@ private:
// Inverse of 2^32.
#define M_RAN_INVM32 2.3283064e-10f
__device__ __inline__ float4 uniform4(uint4 x) {
  return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,
                     x.w * M_RAN_INVM32);
}
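Aside (not part of the diff): operator()() emits four 32-bit lanes per call and uniform4() maps them into [0, 1) by scaling with roughly 2^-32. A hypothetical consumer kernel, assuming a conventional (seed, subsequence, offset) constructor for this class, might derive a byte keep/drop mask like this:

// Hypothetical sketch (not part of the commit) of how a dropout kernel could
// consume this RNG. The (seed, subsequence, offset) constructor arguments are
// assumed from the common Philox convention; only operator()() and uniform4()
// are taken from the header above.
__global__ void philox_dropout_mask_sketch(uint8_t *mask, size_t n, float p,
                                           unsigned long long seed,
                                           unsigned long long offset) {
  size_t tid = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
  Philox rng(seed, tid, offset); // assumed ctor: (seed, subsequence, offset)
  float4 u = uniform4(rng());    // four uniforms in [0, 1)
  size_t base = tid * 4;
  if (base + 0 < n) mask[base + 0] = u.x >= p; // 1 = keep, matching the
  if (base + 1 < n) mask[base + 1] = u.y >= p; // byte masks used elsewhere
  if (base + 2 < n) mask[base + 2] = u.z >= p;
  if (base + 3 < n) mask[base + 3] = u.w >= p;
}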
@@ -5,87 +5,66 @@ namespace multihead_attn {
namespace self {
namespace cublas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                    \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                         \
  CHECK_CUDA(x);                                                               \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &input_weights,
    torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");

  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(
      use_time_mask, is_training, heads, inputs, input_weights, output_weights,
      use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr()) : nullptr,
      dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");

@@ -96,29 +75,28 @@ std::vector<torch::Tensor> bwd(
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_results, inputs, input_weights,
                  output_weights, dropout_mask, dropout_prob);
}

} // end namespace cublas_gemmex

@@ -126,7 +104,8 @@ std::vector<torch::Tensor> bwd(
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::self::cublas_gemmex::fwd,
        "Self Multihead Attention Forward.");
  m.def("backward", &multihead_attn::self::cublas_gemmex::bwd,
        "Self Multihead Attention Backward.");
}
@@ -5,94 +5,70 @@ namespace multihead_attn {
namespace self_bias {
namespace cublas_gemmex {

std::vector<torch::Tensor>
fwd_cuda(bool use_time_mask, bool is_training, int heads,
         torch::Tensor const &inputs, torch::Tensor const &input_weights,
         torch::Tensor const &output_weights, torch::Tensor const &input_biases,
         torch::Tensor const &output_biases, const uint8_t *pad_mask,
         float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    // torch::Tensor const& input_biases,
    // torch::Tensor const& output_biases,
    torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                    \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                         \
  CHECK_CUDA(x);                                                               \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &input_weights,
    torch::Tensor const &output_weights, torch::Tensor const &input_biases,
    torch::Tensor const &output_biases, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");

  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights,
                  output_weights, input_biases, output_biases,
                  use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");

@@ -103,37 +79,37 @@ std::vector<torch::Tensor> bwd(
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_results, inputs, input_weights,
                  output_weights, dropout_mask, dropout_prob);
}

} // end namespace cublas_gemmex
} // namespace self_bias
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::self_bias::cublas_gemmex::fwd,
        "Self Multihead Attention with Bias -- Forward.");
  m.def("backward", &multihead_attn::self_bias::cublas_gemmex::bwd,
        "Self Multihead Attention with Bias -- Backward.");
}
#include <cuda_fp16.h>
#include <torch/extension.h>

#include <vector>

namespace multihead_attn {
namespace self_bias_additive_mask {
namespace cublas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    torch::Tensor const &input_biases,
                                    torch::Tensor const &output_biases,
                                    const half *pad_mask, float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    // torch::Tensor const& softmax_results,
    torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    // torch::Tensor const& input_biases,
    // torch::Tensor const& output_biases,
    torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                    \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                         \
  CHECK_CUDA(x);                                                               \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &input_weights,
    torch::Tensor const &output_weights, torch::Tensor const &input_biases,
    torch::Tensor const &output_biases, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");

  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(use_mask, "no mask is not supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
               "Only Half is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights,
                  output_weights, input_biases, output_biases,
                  use_mask ? static_cast<const half *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");

@@ -107,37 +82,36 @@ std::vector<torch::Tensor> bwd(
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  bmm1_results, pad_mask, input_lin_results, inputs,
                  input_weights, output_weights, dropout_mask, dropout_prob);
}

} // end namespace cublas_gemmex
} // namespace self_bias_additive_mask
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd,
        "Self Multihead Attention with Bias -- Forward.");
  m.def("backward",
        &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd,
        "Self Multihead Attention with Bias -- Backward.");
}
@@ -5,111 +5,86 @@ namespace multihead_attn {
namespace self_norm_add {
namespace cublas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &lyr_nrm_gamma_weights,
                                    torch::Tensor const &lyr_nrm_beta_weights,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results,
    torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
    torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
    torch::Tensor const &lyr_nrm_gamma_weights,
    torch::Tensor const &lyr_nrm_beta_weights,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
    float dropout_prob);

// C++ interface

#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                    \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                         \
  CHECK_CUDA(x);                                                               \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &lyr_nrm_gamma_weights,
    torch::Tensor const &lyr_nrm_beta_weights,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &pad_mask, float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");

  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(
      use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights,
      lyr_nrm_beta_weights, input_weights, output_weights,
      use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr()) : nullptr,
      dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results,
    torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
    torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
    torch::Tensor const &lyr_nrm_gamma_weights,
    torch::Tensor const &lyr_nrm_beta_weights,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
    float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");

@@ -126,40 +101,42 @@ std::vector<torch::Tensor> bwd(
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float,
             "Only FLOAT is supported");
  AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float,
             "Only FLOAT is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");
  AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_results, lyr_nrm_results,
                  lyr_nrm_mean, lyr_nrm_invvar, inputs, lyr_nrm_gamma_weights,
                  lyr_nrm_beta_weights, input_weights, output_weights,
                  dropout_mask, dropout_add_mask, dropout_prob);
}

} // end namespace cublas_gemmex

@@ -167,7 +144,8 @@ std::vector<torch::Tensor> bwd(
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::self_norm_add::cublas_gemmex::fwd,
        "Self Multihead Attention Plus Layer Norm and Residual Add Forward.");
  m.def("backward", &multihead_attn::self_norm_add::cublas_gemmex::bwd,
        "Self Multihead Attention Plus Layer Norm and Residual Add Backward.");
}