"git@developer.sourcefind.cn:change/sglang.git" did not exist on "d98a4913eae3a38a879bdcdc8d9a3fe6c28b85c5"
Commit 9615983e authored by Masaki Kozuki, committed by hubertlu-tw

Remove `THCState` from `apex/contrib/multihead_attn` (#1239)

* pass `self.mask_additive`

* clang-format

* removing THCState
parent d11ddccf
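The change applied throughout the files below is mechanical: the `extern THCState *state;` globals that used to be resolved from the old THC runtime are deleted, and the stream (and cuBLAS handle) are taken from ATen instead, as the surviving hunks already do with `at::cuda::getCurrentCUDAStream()`. A minimal sketch of the replacement pattern, not part of the diff itself; `getCurrentCUDABlasHandle` is an assumption about how the elided hunks obtain `handle`:

// Sketch only (not part of the diff). Before:
//   extern THCState *state;   // symbol resolved by PyTorch's THC libs
// After: no global state; query ATen directly.
#include <ATen/cuda/CUDAContext.h>
#include <cublas_v2.h>

inline cublasHandle_t blas_handle_on_current_stream() {
  // assumption: the elided hunks obtain their cuBLAS handle this way
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);
  return handle;
}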
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {

std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const half *pad_mask, float dropout_prob);

torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x)                                                         \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                   \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                        \
  CHECK_CUDA(x);                                                              \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
                               torch::Tensor const &input,
                               torch::Tensor const &pad_mask,
                               float dropout_prob) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
               "Only BYTE is supported");
  }

  return fwd_cuda(is_training, heads, input,
                  use_mask ? static_cast<const half *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
                  torch::Tensor const &softmax_results,
                  torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  // AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
  //            "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
                  dropout_prob);
}

} // namespace additive_mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd,
        "Self Multihead Attention masked softmax dropout -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd,
        "Self Multihead Attention masked softmax dropout -- Backward.");
}
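A usage sketch, not part of the commit: how these bindings can be exercised from C++ when this file and the matching *_cuda.cu translation unit are compiled together. Tensor shapes follow the checks above (3D half input of size attn_batches x q_seq_len x k_seq_len, 2D half additive mask); the exact mask shape and the 0.1 dropout probability are illustrative assumptions.

#include <torch/torch.h>

void example_forward_backward() {
  const int heads = 8, sequences = 2, q_seq_len = 16, k_seq_len = 16;
  auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
  // input is already the attention-score layout: {heads * sequences, q, k}
  auto input = torch::randn({heads * sequences, q_seq_len, k_seq_len}, opts);
  auto pad_mask = torch::zeros({sequences, k_seq_len}, opts); // additive mask

  // forward returns {dropout_results, dropout_mask, softmax_results}
  auto out = multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd(
      /*use_mask=*/true, /*is_training=*/true, heads, input, pad_mask,
      /*dropout_prob=*/0.1f);

  // backward consumes the upstream grads plus the saved softmax/dropout state
  auto grads = torch::ones_like(out[0]);
  auto input_grads =
      multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd(
          /*use_mask=*/true, heads, grads, /*softmax_results=*/out[2],
          /*dropout_mask=*/out[1], /*dropout_prob=*/0.1f);
}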
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "softmax.h"

// symbol to be automatically resolved by PyTorch libs

namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {

std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const half *pad_mask, float dropout_prob) {
  const int attn_batches = input.size(0);
  const int sequences = attn_batches / heads;
  const int q_seq_len = input.size(1);
@@ -41,63 +35,54 @@ std::vector<torch::Tensor> fwd_cuda(
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = input.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *input_ptr = static_cast<void *>(input.data_ptr());
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
        attn_batches * q_seq_len);
  } else {
    softmax_success = dispatch_additive_masked_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
        k_seq_len, attn_batches * q_seq_len,
        attn_batches * q_seq_len / sequences);
  }

  if (is_training) {
    // use at:: function so that C++ version generates the same random mask as
    // python version
    auto dropout_tuple =
        at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
    dropout_results = std::get<0>(dropout_tuple);
    dropout_mask = std::get<1>(dropout_tuple);
  }

  // Matmul2

  return {dropout_results, dropout_mask, softmax_results};
}

torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask, float dropout_prob) {
  const int attn_batches = output_grads.size(0);
  const int q_seq_len = output_grads.size(1);
  const int k_seq_len = q_seq_len;
@@ -109,23 +94,20 @@ torch::Tensor bwd_cuda(
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  // torch::Tensor input_grads = torch::empty_like(output_grads);

  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
  dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
      static_cast<half *>(output_grads.data_ptr()),
      static_cast<half *>(output_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()),
      static_cast<uint8_t const *>(dropout_mask.data_ptr()),
      1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
      attn_batches * q_seq_len, stream);
  // backward pass is completely in-place
  return output_grads;
}
} // namespace additive_mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
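One consequence of the implementation above that callers can rely on: bwd_cuda applies the dropout mask, rescales by 1.0 / (1.0 - dropout_prob), and runs the softmax gradient in a single in-place dispatch, so the tensor it returns is the very output_grads buffer it was given. A small sketch of that contract, assuming the declarations above are visible to the caller; check_inplace_backward is an illustrative name:

#include <torch/extension.h>

void check_inplace_backward(int heads, torch::Tensor output_grads,
                            torch::Tensor softmax_results,
                            torch::Tensor dropout_mask, float dropout_prob) {
  auto in_grads =
      multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd_cuda(
          heads, output_grads, softmax_results, dropout_mask, dropout_prob);
  // the backward pass is noted above as completely in-place
  TORCH_CHECK(in_grads.data_ptr() == output_grads.data_ptr());
}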
@@ -11,33 +11,22 @@
const int UNROLL = 4;

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void
apex_fused_dropout_kernel(scalar_t const *inputs, scalar_t *outputs,
                          uint8_t *mask, IndexType totalElements, accscalar_t p,
                          std::pair<uint64_t, uint64_t> seeds) {
  accscalar_t pinv = accscalar_t(1) / p;
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seeds.first, idx, seeds.second, &state);

  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
      blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    float4 rand = curand_uniform4(&state);
    scalar_t src[UNROLL];
    rand.x = rand.x <= p;
@@ -54,7 +43,7 @@ __global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        outputs[li] = src[ii] * (&rand.x)[ii] * pinv;
        mask[li] = (uint8_t)(&rand.x)[ii];
      }
    }
@@ -62,34 +51,23 @@ __global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_dropout_add_kernel(scalar_t const *inputs,
                                        scalar_t const *add_inputs,
                                        scalar_t *outputs, uint8_t *mask,
                                        IndexType totalElements, accscalar_t p,
                                        std::pair<uint64_t, uint64_t> seeds) {
  accscalar_t pinv = accscalar_t(1) / p;
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seeds.first, idx, seeds.second, &state);

  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
      blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    float4 rand = curand_uniform4(&state);
    scalar_t src[UNROLL];
    scalar_t add_src[UNROLL];
@@ -108,7 +86,8 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs,
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        accscalar_t int1 = src[ii] * (&rand.x)[ii] * pinv;
        outputs[li] =
            static_cast<scalar_t>(static_cast<accscalar_t>(add_src[ii]) + int1);
        mask[li] = (uint8_t)(&rand.x)[ii];
      }
    }
@@ -116,22 +95,16 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs,
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_add_kernel(scalar_t const *inputs,
                                scalar_t const *add_inputs, scalar_t *outputs,
                                IndexType totalElements) {
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;

  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
      blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    scalar_t src[UNROLL];
    scalar_t add_src[UNROLL];
    for (int ii = 0; ii < UNROLL; ii++) {
@@ -151,23 +124,17 @@ __global__ void apex_add_kernel( scalar_t const *inputs,
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_masked_scale_kernel(scalar_t const *inputs,
                                         scalar_t *outputs, uint8_t const *mask,
                                         IndexType totalElements,
                                         accscalar_t scale) {
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;

  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
      blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    scalar_t src[UNROLL];
    scalar_t msk[UNROLL];
    for (int ii = 0; ii < UNROLL; ii++) {
@@ -180,33 +147,34 @@ __global__ void apex_masked_scale_kernel(scalar_t const *inputs,
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        outputs[li] = static_cast<accscalar_t>(src[ii]) * scale *
                      static_cast<accscalar_t>(msk[ii]);
      }
    }
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_fused_dropout_cuda(scalar_t const *inputs, scalar_t *outputs,
                             uint8_t *mask, IndexType totalElements,
                             accscalar_t p) {
  auto gen = at::cuda::detail::getDefaultCUDAGenerator();

  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
      block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
                            ->multiProcessorCount *
                        blocks_per_sm,
                    grid.x);

  // number of times random will be generated per thread, to offset philox
  // counter in the random state
  int64_t counter_offset =
      ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
  std::pair<uint64_t, uint64_t> rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
@@ -215,36 +183,39 @@ void apex_fused_dropout_cuda(scalar_t const *inputs,
    rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
#else
    std::lock_guard<std::mutex> lock(gen.mutex());
    rng_engine_inputs =
        at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
            counter_offset);
#endif
  }
  apex_fused_dropout_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, outputs, mask, totalElements, p, rng_engine_inputs);
  C10_CUDA_CHECK(cudaGetLastError());
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_dropout_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs,
                           scalar_t *outputs, uint8_t *mask,
                           IndexType totalElements, accscalar_t p) {
  auto gen = at::cuda::detail::getDefaultCUDAGenerator();

  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
      block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
                            ->multiProcessorCount *
                        blocks_per_sm,
                    grid.x);

  // number of times random will be generated per thread, to offset philox
  // counter in the random state
  int64_t counter_offset =
      ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
  std::pair<uint64_t, uint64_t> rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
@@ -253,54 +224,56 @@ void apex_dropout_add_cuda(scalar_t const *inputs,
    rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
#else
    std::lock_guard<std::mutex> lock(gen.mutex());
    rng_engine_inputs =
        at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
            counter_offset);
#endif
  }
  apex_dropout_add_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, add_inputs, outputs, mask, totalElements, p,
          rng_engine_inputs);
  C10_CUDA_CHECK(cudaGetLastError());
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs,
                   scalar_t *outputs, IndexType totalElements) {
  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
      block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
                            ->multiProcessorCount *
                        blocks_per_sm,
                    grid.x);

  apex_add_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, add_inputs, outputs, totalElements);
  C10_CUDA_CHECK(cudaGetLastError());
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_masked_scale_cuda(scalar_t const *inputs, scalar_t *outputs,
                            uint8_t const *mask, IndexType totalElements,
                            accscalar_t scale) {
  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
      block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
                            ->multiProcessorCount *
                        blocks_per_sm,
                    grid.x);

  apex_masked_scale_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, outputs, mask, totalElements, scale);
  C10_CUDA_CHECK(cudaGetLastError());
}
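For concreteness, the launch arithmetic shared by the four launchers above can be reproduced on the host. The device limits (2048 threads per SM, 80 SMs) and the element count below are illustrative stand-ins for the cudaDeviceProp queries, not values taken from the diff:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t totalElements = 1 << 20; // e.g. attn_batches * q_len * k_len
  const unsigned int block_size = 256, UNROLL = 4;
  const unsigned int max_threads_per_sm = 2048, sm_count = 80; // assumed HW

  unsigned int grid_x = (totalElements + block_size - 1) / block_size; // 4096
  unsigned int blocks_per_sm = max_threads_per_sm / block_size;        // 8
  grid_x = std::min(sm_count * blocks_per_sm, grid_x);                 // 640

  // each thread calls curand_uniform4 once per UNROLL-wide chunk it covers
  int64_t counter_offset =
      ((totalElements - 1) / (block_size * grid_x * UNROLL) + 1) * UNROLL; // 8
  std::printf("grid=%u counter_offset=%lld\n", grid_x,
              (long long)counter_offset);
  return 0;
}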
@@ -5,103 +5,79 @@ namespace multihead_attn {
namespace encdec {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs_q,
                                    torch::Tensor const &inputs_kv,
                                    torch::Tensor const &input_weights_q,
                                    torch::Tensor const &input_weights_kv,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_q_results,
    torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
    torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
    torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x)                                                         \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                   \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                        \
  CHECK_CUDA(x);                                                              \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
    torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
    torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");

  AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
                  input_weights_q, input_weights_kv, output_weights,
                  use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_q_results,
    torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
    torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
    torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
@@ -115,35 +91,35 @@ std::vector<torch::Tensor> bwd(
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_q_results, input_lin_kv_results,
                  inputs_q, inputs_kv, input_weights_q, input_weights_kv,
                  output_weights, dropout_mask, dropout_prob);
}

} // end namespace rocblas_gemm_ex
...
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "layer_norm.h"
#include "softmax.h"
#include "strided_batched_gemm.h"

namespace multihead_attn {
namespace encdec {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs_q,
                                    torch::Tensor const &inputs_kv,
                                    torch::Tensor const &input_weights_q,
                                    torch::Tensor const &input_weights_kv,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob) {
  const int embed_dim = inputs_q.size(2);
  const int sequences = inputs_q.size(1);
  const int q_seq_len = inputs_q.size(0);
@@ -48,7 +39,7 @@ std::vector<torch::Tensor> fwd_cuda(
  const int output_lin_kv_dim = 2 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim_q = attn_batches * head_dim;
  const int lead_dim_kv = attn_batches * 2 * head_dim;
  const int batch_stride_q = head_dim;
  const int batch_stride_kv = 2 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
@@ -62,25 +53,34 @@ std::vector<torch::Tensor> fwd_cuda(
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = inputs_q.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor input_lin_q_results =
      torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options);
  torch::Tensor input_lin_kv_results =
      torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor outputs = torch::empty_like(inputs_q, act_options);

  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_q_results.data_ptr());
  void *k_lin_results_ptr =
      static_cast<void *>(input_lin_kv_results.data_ptr());
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
@@ -166,40 +166,30 @@ std::vector<torch::Tensor> fwd_cuda(
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
        k_seq_len, attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len,
          attn_batches * q_seq_len / sequences);
    }
  }
  assert(softmax_success);

  if (is_training) {
    apex_fused_dropout_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(softmax_results.data_ptr()),
        static_cast<at::Half *>(dropout_results.data_ptr()),
        static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
        (1.0f - dropout_prob));
  }
@@ -253,34 +243,24 @@ std::vector<torch::Tensor> fwd_cuda(
      flags));
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_lin_q_results,
          input_lin_kv_results,
          softmax_results,
          dropout_results,
          dropout_mask,
          matmul2_results,
          outputs};
}

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_q_results,
    torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
    torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
    torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  const int embed_dim = inputs_q.size(2);
  const int sequences = inputs_q.size(1);
  const int q_seq_len = inputs_q.size(0);
@@ -292,7 +272,7 @@ std::vector<torch::Tensor> bwd_cuda(
  const int output_lin_kv_dim = 2 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim_q = attn_batches * head_dim;
  const int lead_dim_kv = attn_batches * 2 * head_dim;
  const int batch_stride_q = head_dim;
  const int batch_stride_kv = 2 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
@@ -316,15 +296,20 @@ std::vector<torch::Tensor> bwd_cuda(
  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads = torch::empty_like(dropout_results);
  at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results);
  at::Tensor input_lin_kv_output_grads =
      torch::empty_like(input_lin_kv_results);

  auto q_lin_results_ptr = static_cast<half *>(input_lin_q_results.data_ptr());
  auto k_lin_results_ptr = static_cast<half *>(input_lin_kv_results.data_ptr());
  auto v_lin_results_ptr =
      static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim;

  auto q_lin_grads_ptr =
      static_cast<half *>(input_lin_q_output_grads.data_ptr());
  auto k_lin_grads_ptr =
      static_cast<half *>(input_lin_kv_output_grads.data_ptr());
  auto v_lin_grads_ptr =
      static_cast<half *>(input_lin_kv_output_grads.data_ptr()) + head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
@@ -442,12 +427,10 @@ std::vector<torch::Tensor> bwd_cuda(
  // Softmax Grad
  bool softmax_success = false;
  softmax_success = dispatch_softmax_backward<half, half, float>(
      static_cast<half *>(matmul2_grads.data_ptr()),
      static_cast<half *>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
      k_seq_len, attn_batches * q_seq_len);
  assert(softmax_success);

  // Matmul1 Dgrad1
...
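One layout detail worth spelling out from fwd_cuda/bwd_cuda above: input_lin_kv_results holds K and V interleaved per attention batch, which is why the V pointers are simply the KV base plus head_dim, the per-batch stride is 2 * head_dim, and lead_dim_kv is attn_batches * 2 * head_dim. A hedged helper expressing that reading (k_block/v_block are illustrative names, not part of the source):

#include <cstddef>
#include <cuda_fp16.h>

// Illustrative only: mirrors the pointer arithmetic used above.
inline const half *k_block(const half *kv_base, int batch, int head_dim) {
  return kv_base + static_cast<std::ptrdiff_t>(batch) * 2 * head_dim;
}
inline const half *v_block(const half *kv_base, int batch, int head_dim) {
  return k_block(kv_base, batch, head_dim) + head_dim; // V follows its K
}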
...@@ -5,66 +5,49 @@ namespace multihead_attn { ...@@ -5,66 +5,49 @@ namespace multihead_attn {
namespace encdec_norm_add { namespace encdec_norm_add {
namespace rocblas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
bool use_time_mask, int heads, torch::Tensor const &inputs_q,
bool is_training, torch::Tensor const &inputs_kv,
int heads, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const& inputs_q, torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const& inputs_kv, torch::Tensor const &input_weights_q,
torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const &input_weights_kv,
torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const &output_weights,
torch::Tensor const& input_weights_q, const uint8_t *pad_mask,
torch::Tensor const& input_weights_kv, float dropout_prob);
torch::Tensor const& output_weights,
const uint8_t* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> bwd_cuda( std::vector<torch::Tensor> bwd_cuda(
int heads, int heads, torch::Tensor const &output_grads,
torch::Tensor const& output_grads, torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const& matmul2_results, torch::Tensor const &softmax_results,
torch::Tensor const& dropout_results, torch::Tensor const &input_lin_q_results,
torch::Tensor const& softmax_results, torch::Tensor const &input_lin_kv_results,
torch::Tensor const& input_lin_q_results, torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const& input_lin_kv_results, torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q,
torch::Tensor const& lyr_nrm_results, torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_mean, torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const& lyr_nrm_invvar, torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const& inputs_q, torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
torch::Tensor const& inputs_kv, torch::Tensor const &dropout_add_mask, float dropout_prob);
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
torch::Tensor const& dropout_add_mask,
float dropout_prob
);
// C++ interface // C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CUDA(x) \
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) #define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
    torch::Tensor const &lyr_nrm_gamma_weights,
    torch::Tensor const &lyr_nrm_beta_weights,
    torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
    torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
@@ -73,58 +56,48 @@ std::vector<torch::Tensor> fwd(
  AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");

  AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
                  lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q,
                  input_weights_kv, output_weights,
                  use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_q_results,
    torch::Tensor const &input_lin_kv_results,
    torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
    torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q,
    torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights,
    torch::Tensor const &lyr_nrm_beta_weights,
    torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
    torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
    torch::Tensor const &dropout_add_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
@@ -144,47 +117,49 @@ std::vector<torch::Tensor> bwd(
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float,
             "Only FLOAT is supported");
  AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float,
             "Only FLOAT is supported");
  AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");
  AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_q_results, input_lin_kv_results,
                  lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs_q,
                  inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights,
                  input_weights_q, input_weights_kv, output_weights,
                  dropout_mask, dropout_add_mask, dropout_prob);
}

} // end namespace cublas_gemmex
@@ -195,4 +170,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::encdec_norm_add::rocblas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.");
  m.def("backward", &multihead_attn::encdec_norm_add::rocblas_gemmex::bwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.");
}
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "layer_norm.h"
#include "softmax.h"
#include "strided_batched_gemm.h"

-// symbol to be automatically resolved by PyTorch libs
-extern THCState *state;

namespace multihead_attn {
namespace encdec_norm_add {
@@ -64,7 +61,8 @@ std::vector<torch::Tensor> fwd_cuda(
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = inputs_q.options().requires_grad(false);
  auto lyr_nrm_options = act_options.dtype(torch::kFloat32);
  auto mask_options = act_options.dtype(torch::kUInt8);
@@ -73,23 +71,31 @@ std::vector<torch::Tensor> fwd_cuda(
  torch::Tensor lyr_nrm_invvar = torch::empty({batches_q}, lyr_nrm_options);
  torch::Tensor lyr_nrm_results = torch::empty_like(inputs_q, act_options);

  torch::Tensor input_lin_q_results =
      torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options);
  torch::Tensor input_lin_kv_results =
      torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor output_lin_results = torch::empty_like(inputs_q, act_options);
  torch::Tensor dropout_add_mask = torch::empty_like(inputs_q, mask_options);
  torch::Tensor outputs = torch::empty_like(inputs_q, act_options);
  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_q_results.data_ptr());
  void *k_lin_results_ptr =
      static_cast<void *>(input_lin_kv_results.data_ptr());
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
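The pointer arithmetic above packs K and V per head inside a single KV projection buffer; the standalone sketch below spells out the assumed layout (function and variable names are illustrative, not from this file):

#include <cuda_fp16.h>
// Sketch: within each per-head slab of 2 * head_dim elements, the first
// head_dim entries are K and the next head_dim entries are V, which is why
// the backward pass uses batch_stride_kv = 2 * head_dim and
// lead_dim_kv = attn_batches * 2 * head_dim.
inline void split_kv_pointers(half *kv_base, int head_dim, half **k_base,
                              half **v_base) {
  *k_base = kv_base;            // K occupies the first head_dim slots per head
  *v_base = kv_base + head_dim; // V follows immediately within the same slab
}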
  char a_layout_t{'t'};
  char a_layout_n{'n'};
@@ -97,16 +103,15 @@ std::vector<torch::Tensor> fwd_cuda(
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Layer Norm
  HostApplyLayerNorm<at::Half, float>(
      static_cast<at::Half *>(lyr_nrm_results.data_ptr()),
      static_cast<float *>(lyr_nrm_mean.data_ptr()),
      static_cast<float *>(lyr_nrm_invvar.data_ptr()),
      static_cast<const at::Half *>(inputs_q.data_ptr()),
      static_cast<int>(batches_q), // n1
      static_cast<int>(embed_dim), // n2
      1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
      static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));

  // Input Linear Q Fwd
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -187,40 +192,30 @@ std::vector<torch::Tensor> fwd_cuda(
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
        k_seq_len, attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len,
          attn_batches * q_seq_len / sequences);
    }
  }
  assert(softmax_success);

  if (is_training) {
    apex_fused_dropout_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(softmax_results.data_ptr()),
        static_cast<at::Half *>(dropout_results.data_ptr()),
        static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
        (1.0f - dropout_prob));
  }
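For reference, a minimal host-side sketch of the inverted-dropout scheme that apex_fused_dropout_cuda is assumed to implement: keep each element with probability 1 - dropout_prob, rescale the survivors by the reciprocal, and record the decision in a byte mask (names and the RNG callback are illustrative):

#include <cstdint>
#include <functional>
// Sketch only: element-wise inverted dropout with keep probability p_keep.
inline void dropout_reference(const float *in, float *out, uint8_t *mask,
                              int n, float p_keep,
                              const std::function<float()> &uniform01) {
  for (int i = 0; i < n; ++i) {
    mask[i] = uniform01() < p_keep ? 1 : 0;    // record keep/drop decision
    out[i] = mask[i] ? in[i] / p_keep : 0.0f;  // rescale kept elements
  }
}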
@@ -276,25 +271,22 @@ std::vector<torch::Tensor> fwd_cuda(
  // End-of-block Dropout-Add
  if (is_training) {
    apex_dropout_add_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(output_lin_results.data_ptr()),
        static_cast<at::Half const *>(inputs_q.data_ptr()),
        static_cast<at::Half *>(outputs.data_ptr()),
        static_cast<uint8_t *>(dropout_add_mask.data_ptr()), total_tokens_q,
        (1.0f - dropout_prob));
  } else {
    apex_add_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(output_lin_results.data_ptr()),
        static_cast<at::Half const *>(inputs_q.data_ptr()),
        static_cast<at::Half *>(outputs.data_ptr()), total_tokens_q);
  }

  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {lyr_nrm_results,
          lyr_nrm_mean,
          lyr_nrm_invvar,
          input_lin_q_results,
@@ -304,33 +296,22 @@ std::vector<torch::Tensor> fwd_cuda(
          dropout_mask,
          matmul2_results,
          dropout_add_mask,
          outputs};
}

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_q_results,
    torch::Tensor const &input_lin_kv_results,
    torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
    torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q,
    torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights,
    torch::Tensor const &lyr_nrm_beta_weights,
    torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
    torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
    torch::Tensor const &dropout_add_mask, float dropout_prob) {
  const int embed_dim = inputs_q.size(2);
  const int sequences = inputs_q.size(1);
  const int q_seq_len = inputs_q.size(0);
@@ -343,7 +324,7 @@ std::vector<torch::Tensor> bwd_cuda(
  const int output_lin_kv_dim = 2 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim_q = attn_batches * head_dim;
  const int lead_dim_kv = attn_batches * 2 * head_dim;
  const int batch_stride_q = head_dim;
  const int batch_stride_kv = 2 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
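For concreteness, a hypothetical shape walk-through of this bookkeeping; the sizes below are invented for illustration and do not come from the source:

// Hypothetical sizes, chosen only to make the arithmetic above concrete.
constexpr int embed_dim = 1024, heads = 16, sequences = 8;
constexpr int q_seq_len = 32, k_seq_len = 48;
constexpr int head_dim = embed_dim / heads;                         // 64
constexpr int output_lin_kv_dim = 2 * embed_dim;                    // 2048
constexpr int attn_batches = heads * sequences;                     // 128
constexpr int lead_dim_kv = attn_batches * 2 * head_dim;            // 16384
constexpr int batch_stride_kv = 2 * head_dim;                       // 128
constexpr int dropout_elems = attn_batches * q_seq_len * k_seq_len; // 196608
static_assert(heads * head_dim == embed_dim, "embed_dim must split evenly");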
@@ -370,16 +351,21 @@ std::vector<torch::Tensor> bwd_cuda(
  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads = torch::empty_like(dropout_results);
  at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results);
  at::Tensor input_lin_kv_output_grads =
      torch::empty_like(input_lin_kv_results);
  at::Tensor input_lin_q_grads = torch::empty_like(inputs_q);

  auto q_lin_results_ptr = static_cast<half *>(input_lin_q_results.data_ptr());
  auto k_lin_results_ptr = static_cast<half *>(input_lin_kv_results.data_ptr());
  auto v_lin_results_ptr =
      static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim;

  auto q_lin_grads_ptr =
      static_cast<half *>(input_lin_q_output_grads.data_ptr());
  auto k_lin_grads_ptr =
      static_cast<half *>(input_lin_kv_output_grads.data_ptr());
  auto v_lin_grads_ptr =
      static_cast<half *>(input_lin_kv_output_grads.data_ptr()) + head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
@@ -505,12 +491,10 @@ std::vector<torch::Tensor> bwd_cuda(
  // Softmax Grad
  bool softmax_success = false;
  softmax_success = dispatch_softmax_backward<half, half, float>(
      static_cast<half *>(matmul2_grads.data_ptr()),
      static_cast<half *>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
      k_seq_len, attn_batches * q_seq_len);
  assert(softmax_success);
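As a reference for the quantity this dispatch is assumed to compute per attention row, the standard softmax backward is dx_i = y_i * (dy_i - sum_j dy_j * y_j); a host-side sketch, not the library kernel:

// Sketch: given y = softmax(x) and upstream gradient dy, compute dx per row.
inline void softmax_backward_row(const float *y, const float *dy, float *dx,
                                 int n) {
  float dot = 0.0f;
  for (int j = 0; j < n; ++j)
    dot += dy[j] * y[j];
  for (int i = 0; i < n; ++i)
    dx[i] = y[i] * (dy[i] - dot);
}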
  // Matmul1 Dgrad1
@@ -683,15 +667,9 @@ std::vector<torch::Tensor> bwd_cuda(
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_q_grads,      input_kv_grads,       lyr_nrm_gamma_grads,
          lyr_nrm_beta_grads, input_weight_q_grads, input_weight_kv_grads,
          output_weight_grads};
}

} // end namespace rocblas_gemmex
...
@@ -5,81 +5,66 @@ namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {

std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const uint8_t *pad_mask,
                                    float dropout_prob);

torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask,
                       const uint8_t *padding_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x) \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
                               torch::Tensor const &input,
                               torch::Tensor const &pad_mask,
                               float dropout_prob) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(is_training, heads, input,
                  use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
                  torch::Tensor const &softmax_results,
                  torch::Tensor const &dropout_mask,
                  torch::Tensor const &padding_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  // AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
  //            "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
                  use_mask
                      ? static_cast<const uint8_t *>(padding_mask.data_ptr())
                      : nullptr,
                  dropout_prob);
}

} // end namespace mask_softmax_dropout
@@ -87,7 +72,8 @@ torch::Tensor bwd(
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd,
        "Self Multihead Attention masked softmax dropout -- Forward.");
  m.def("backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd,
        "Self Multihead Attention masked softmax dropout -- Backward.");
}
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "softmax.h"

-// symbol to be automatically resolved by PyTorch libs
-extern THCState *state;

namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {

std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const uint8_t *pad_mask,
                                    float dropout_prob) {
  const int attn_batches = input.size(0);
  const int sequences = attn_batches / heads;
  const int q_seq_len = input.size(1);
@@ -41,64 +34,55 @@ std::vector<torch::Tensor> fwd_cuda(
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = input.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *input_ptr = static_cast<void *>(input.data_ptr());
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
        attn_batches * q_seq_len);
  } else {
    softmax_success = dispatch_masked_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
        k_seq_len, attn_batches * q_seq_len,
        attn_batches * q_seq_len / sequences);
  }

  if (is_training) {
    // use at:: function so that C++ version generates the same random mask as
    // python version
    auto dropout_tuple =
        at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
    dropout_results = std::get<0>(dropout_tuple);
    dropout_mask = std::get<1>(dropout_tuple);
  }
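The fused-dropout call above returns a pair whose assumed semantics are worth stating explicitly; the hypothetical wrapper below only restates them (pre-scaled output first, uint8 keep mask second):

#include <torch/extension.h>
#include <tuple>
// Assumed semantics of at::_fused_dropout(x, p_keep), as relied on above:
//   std::get<0>(result) == x * mask / p_keep  (pre-scaled activations)
//   std::get<1>(result) == keep mask          (uint8, same shape as x)
std::tuple<torch::Tensor, torch::Tensor>
fused_dropout_sketch(torch::Tensor const &x, double dropout_prob) {
  return at::_fused_dropout(x, 1.0 - dropout_prob);
}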
  // Matmul2

  return {dropout_results, dropout_mask, softmax_results};
}

torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask,
                       const uint8_t *padding_mask, float dropout_prob) {
  const int attn_batches = output_grads.size(0);
  const int q_seq_len = output_grads.size(1);
  const int k_seq_len = q_seq_len;
@@ -110,38 +94,31 @@ torch::Tensor bwd_cuda(
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  // torch::Tensor input_grads = torch::empty_like(output_grads);

  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
  if (padding_mask == nullptr) {
    dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
        static_cast<half *>(output_grads.data_ptr()),
        static_cast<half *>(output_grads.data_ptr()),
        reinterpret_cast<half const *>(softmax_results.data_ptr()),
        static_cast<uint8_t const *>(dropout_mask.data_ptr()),
        1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
        attn_batches * q_seq_len, stream);
  } else {
    dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float,
                                                             false>(
        static_cast<half *>(output_grads.data_ptr()),
        static_cast<half *>(output_grads.data_ptr()),
        reinterpret_cast<half const *>(softmax_results.data_ptr()),
        static_cast<uint8_t const *>(dropout_mask.data_ptr()),
        static_cast<uint8_t const *>(padding_mask), 1.0 / (1.0 - dropout_prob),
        k_seq_len, k_seq_len, attn_batches * q_seq_len, heads, stream);
  }
  // backward pass is completely in-place
  return output_grads;
}
} // namespace mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
#pragma once
// Philox CUDA.

class Philox {
public:
@@ -15,28 +15,30 @@ public:
    incr_n(offset / 4);
  }

  __device__ inline uint4 operator()() {
    if (STATE == 0) {
      uint4 counter_ = counter;
      uint2 key_ = key;
      // 7-round philox
      for (int i = 0; i < 6; i++) {
        counter_ = single_round(counter_, key_);
        key_.x += (kPhilox10A);
        key_.y += (kPhilox10B);
      }
      output = single_round(counter_, key_);
      incr();
    }
    // return a float4 directly
    // unsigned long ret;
    // switch(STATE) {
    // case 0: ret = output.x; break;
    // case 1: ret = output.y; break;
    // case 2: ret = output.z; break;
    // case 3: ret = output.w; break;
    //}
    // STATE = (STATE + 1) % 4;
    return output;
  }

private:
  uint4 counter;
  uint4 output;
@@ -67,7 +69,7 @@ private:
  __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
                                    unsigned int *result_high) {
    *result_high = __umulhi(a, b);
    return a * b;
  }

  __device__ inline uint4 single_round(uint4 ctr, uint2 key) {
    unsigned int hi0;
@@ -85,6 +87,6 @@ private:
// Inverse of 2^32.
#define M_RAN_INVM32 2.3283064e-10f
__device__ __inline__ float4 uniform4(uint4 x) {
  return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,
                     x.w * M_RAN_INVM32);
}
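A small device-side sketch of how the generator above is typically consumed; the (seed, subsequence, offset) constructor signature is inferred from the excerpt and should be treated as an assumption, and the function name is illustrative:

// Sketch: each thread seeds a counter-based Philox stream, draws a uint4,
// and maps it to four uniforms in [0, 1), e.g. for dropout keep decisions.
__device__ void draw_four_uniforms(unsigned long long seed,
                                   unsigned long long subsequence,
                                   unsigned long long offset, float4 *out) {
  Philox rng(seed, subsequence, offset); // assumed constructor signature
  *out = uniform4(rng());                // operator() yields one uint4
}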
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace self_bias_additive_mask {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    torch::Tensor const &input_biases,
                                    torch::Tensor const &output_biases,
                                    const half *pad_mask, float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    // torch::Tensor const& softmax_results,
    torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    // torch::Tensor const& input_biases,
    // torch::Tensor const& output_biases,
    torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x) \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &input_weights,
    torch::Tensor const &output_weights, torch::Tensor const &input_biases,
    torch::Tensor const &output_biases, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");

  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(use_mask, "no mask is not supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
               "Only Half is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights,
                  output_weights, input_biases, output_biases,
                  use_mask ? static_cast<const half *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
@@ -107,29 +82,26 @@ std::vector<torch::Tensor> bwd(
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  bmm1_results, pad_mask, input_lin_results, inputs,
                  input_weights, output_weights, dropout_mask, dropout_prob);
}

} // end namespace rocblas_gemmex
@@ -140,4 +112,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
  m.def("backward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "layer_norm.h"
#include "softmax.h"
#include "strided_batched_gemm.h"

-// symbol to be automatically resolved by PyTorch libs
-extern THCState *state;

namespace multihead_attn {
namespace self_bias_additive_mask {
@@ -58,25 +55,33 @@ std::vector<torch::Tensor> fwd_cuda(
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = inputs.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor input_lin_results =
      torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor bmm1_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor outputs = torch::empty_like(inputs, act_options);
  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
  void *k_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *bmm1_results_ptr = static_cast<void *>(bmm1_results.data_ptr());
  void *dropout_results_ptr = static_cast<void *>(dropout_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
@@ -136,27 +141,24 @@ std::vector<torch::Tensor> fwd_cuda(
  // Padded Softmax
  bool softmax_success = false;
  if (is_training) {
    softmax_success =
        dispatch_additive_masked_softmax_dropout<half, half, float>(
            reinterpret_cast<half *>(dropout_results_ptr),
            (is_training)
                ? reinterpret_cast<uint8_t *>(dropout_mask.data_ptr<uint8_t>())
                : nullptr,
            reinterpret_cast<const half *>(bmm1_results_ptr), pad_mask,
            attn_batches * q_seq_len * q_seq_len, k_seq_len, k_seq_len,
            attn_batches * q_seq_len, attn_batches * q_seq_len / sequences,
            1.0f - dropout_prob, stream);
  } else {
    softmax_success = dispatch_additive_masked_softmax<half, half, float>(
        reinterpret_cast<half *>(
            dropout_results_ptr), // this is actually softmax results, but
                                  // making it consistent for the next function
        reinterpret_cast<const half *>(bmm1_results_ptr), pad_mask, k_seq_len,
        k_seq_len, attn_batches * q_seq_len,
        attn_batches * q_seq_len / sequences);
  }
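For orientation, a host-side sketch of the additive-mask convention assumed here: the pad mask is added to the raw attention scores before a numerically stable softmax, so heavily negative mask entries drive those positions to near-zero probability (reference code only, not the fused kernel):

#include <cmath>
// Sketch: additive masking followed by a max-shifted softmax over one row.
inline void additive_masked_softmax_row(const float *scores, const float *mask,
                                        float *probs, int n) {
  float max_v = -INFINITY;
  for (int i = 0; i < n; ++i)
    max_v = fmaxf(max_v, scores[i] + mask[i]); // the mask is simply added
  float denom = 0.0f;
  for (int i = 0; i < n; ++i) {
    probs[i] = expf(scores[i] + mask[i] - max_v);
    denom += probs[i];
  }
  for (int i = 0; i < n; ++i)
    probs[i] /= denom;
}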
// Matmul2 // Matmul2
...@@ -211,31 +213,17 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -211,31 +213,17 @@ std::vector<torch::Tensor> fwd_cuda(
flags)); flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {input_lin_results, bmm1_results, dropout_results,
input_lin_results, dropout_mask, matmul2_results, outputs};
bmm1_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs
};
} }
std::vector<torch::Tensor> bwd_cuda( std::vector<torch::Tensor> bwd_cuda(
int heads, int heads, torch::Tensor const &output_grads,
torch::Tensor const& output_grads, torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const& matmul2_results, torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
torch::Tensor const& dropout_results, torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const& bmm1_results, torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const& pad_mask, torch::Tensor const &dropout_mask, float dropout_prob) {
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
const int embed_dim = inputs.size(2); const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1); const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0); const int q_seq_len = inputs.size(0);
...@@ -266,13 +254,17 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -266,13 +254,17 @@ std::vector<torch::Tensor> bwd_cuda(
at::Tensor matmul2_grads = torch::empty_like(dropout_results); at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results); at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()); auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim; auto k_lin_results_ptr =
auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim; static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;
auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()); auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim; auto k_lin_grads_ptr =
auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim; static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;
char a_layout_n{'n'}; char a_layout_n{'n'};
char a_layout_t{'t'}; char a_layout_t{'t'};
...@@ -496,13 +488,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -496,13 +488,8 @@ std::vector<torch::Tensor> bwd_cuda(
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {input_grads, input_weight_grads, output_weight_grads,
input_grads, input_bias_grads, output_bias_grads};
input_weight_grads,
output_weight_grads,
input_bias_grads,
output_bias_grads
};
} }
} // end namespace rocblas_gemmex } // end namespace rocblas_gemmex
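The fused softmax + dropout dispatch above takes the keep probability (1.0f - dropout_prob), and the matching backward kernels later in this diff rescale by 1.0 / (1.0 - dropout_prob); this is the usual inverted-dropout convention. A minimal scalar sketch of that convention, for illustration only (apply_inverted_dropout below is hypothetical, not apex code):

#include <cstdint>
#include <vector>

// Hypothetical scalar reference for the inverted-dropout convention used by
// the fused kernels: kept elements are scaled by 1/(1-p) in the forward pass,
// and the same mask/scale pair is replayed in the backward pass.
std::vector<float> apply_inverted_dropout(const std::vector<float> &x,
                                          const std::vector<uint8_t> &mask,
                                          float dropout_prob) {
  const float scale = 1.0f / (1.0f - dropout_prob);
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    // mask[i] == 1 means "keep" in this sketch; the real kernels also store
    // the mask as uint8_t.
    y[i] = mask[i] ? x[i] * scale : 0.0f;
  }
  return y;
}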
...
@@ -5,94 +5,70 @@ namespace multihead_attn {
namespace self_bias {
namespace rocblas_gemmex {

std::vector<torch::Tensor>
fwd_cuda(bool use_time_mask, bool is_training, int heads,
         torch::Tensor const &inputs, torch::Tensor const &input_weights,
         torch::Tensor const &output_weights, torch::Tensor const &input_biases,
         torch::Tensor const &output_biases, const uint8_t *pad_mask,
         float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    // torch::Tensor const& input_biases,
    // torch::Tensor const& output_biases,
    torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x) \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &input_weights,
    torch::Tensor const &output_weights, torch::Tensor const &input_biases,
    torch::Tensor const &output_biases, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights,
                  output_weights, input_biases, output_biases,
                  use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
@@ -103,29 +79,28 @@ std::vector<torch::Tensor> bwd(
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_results, inputs, input_weights,
                  output_weights, dropout_mask, dropout_prob);
}

} // end namespace rocblas_gemmex
@@ -136,4 +111,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::self_bias::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
  m.def("backward", &multihead_attn::self_bias::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
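The checks in fwd and bwd above pin down the calling contract for this binding: activations are 3D half-precision tensors that the kernels read as [q_seq_len, sequences, embed_dim], projection weights are 2D half tensors, and the optional pad_mask is a 2D byte tensor. A minimal libtorch sketch of tensors that satisfy those asserts; the concrete shapes (and the packed 3 * embed_dim input projection) are assumptions chosen for illustration, not something the binding itself checks:

#include <torch/torch.h>

int main() {
  const int64_t seq_len = 64, batch = 8, embed_dim = 1024;

  auto half_opts =
      torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
  auto byte_opts =
      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);

  // 3D half activations, read by the kernels as [q_seq_len, sequences, embed_dim].
  torch::Tensor inputs = torch::randn({seq_len, batch, embed_dim}, half_opts);
  // 2D half projection weights; the packed QKV shape below is an assumption.
  torch::Tensor input_weights =
      torch::randn({3 * embed_dim, embed_dim}, half_opts);
  torch::Tensor output_weights =
      torch::randn({embed_dim, embed_dim}, half_opts);
  // Biases for the packed input projection and the output projection.
  torch::Tensor input_biases = torch::zeros({3 * embed_dim}, half_opts);
  torch::Tensor output_biases = torch::zeros({embed_dim}, half_opts);
  // 2D byte padding mask (shape assumed here), passed only when use_mask is true.
  torch::Tensor pad_mask = torch::zeros({batch, seq_len}, byte_opts);

  // These tensors would then be handed to the bound
  // multihead_attn::self_bias::rocblas_gemmex::fwd entry point.
  return 0;
}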
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "layer_norm.h"
#include "softmax.h"
#include "strided_batched_gemm.h"

namespace multihead_attn {
namespace self_bias {
namespace rocblas_gemmex {

std::vector<torch::Tensor>
fwd_cuda(bool use_time_mask, bool is_training, int heads,
         torch::Tensor const &inputs, torch::Tensor const &input_weights,
         torch::Tensor const &output_weights, torch::Tensor const &input_biases,
         torch::Tensor const &output_biases, const uint8_t *pad_mask,
         float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
@@ -58,24 +48,32 @@ std::vector<torch::Tensor> fwd_cuda(
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = inputs.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);
  torch::Tensor input_lin_results =
      torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor outputs = torch::empty_like(inputs, act_options);

  // Input Linear Results Pointers to Q, K, and V of interviewed activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
  void *k_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
@@ -136,37 +134,29 @@ std::vector<torch::Tensor> fwd_cuda(
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
        k_seq_len, attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len,
          attn_batches * q_seq_len / sequences);
    }
  }

  if (is_training) {
    // use at:: function so that C++ version generates the same random mask as
    // python version
    auto dropout_tuple =
        at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
    dropout_results = std::get<0>(dropout_tuple);
    dropout_mask = std::get<1>(dropout_tuple);
  }
@@ -223,30 +213,17 @@ std::vector<torch::Tensor> fwd_cuda(
      flags));
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_lin_results, softmax_results, dropout_results,
          dropout_mask, matmul2_results, outputs};
}

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
@@ -277,13 +254,17 @@ std::vector<torch::Tensor> bwd_cuda(
  at::Tensor matmul2_grads = torch::empty_like(dropout_results);
  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);

  auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
  auto k_lin_results_ptr =
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr =
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr =
      static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr =
      static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
@@ -393,15 +374,13 @@ std::vector<torch::Tensor> bwd_cuda(
  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
  dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
      static_cast<half *>(matmul2_grads.data_ptr()),
      static_cast<half *>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()),
      static_cast<uint8_t const *>(dropout_mask.data_ptr()),
      1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
      attn_batches * q_seq_len, stream);

  // Matmul1 Dgrad1
  gemm_switch_fp32accum( state,
@@ -503,13 +482,8 @@ std::vector<torch::Tensor> bwd_cuda(
  auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_grads, input_weight_grads, output_weight_grads,
          input_bias_grads, output_bias_grads};
}

} // end namespace rocblas_gemmex
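In the training path above, this variant draws its dropout mask through at::_fused_dropout with the keep probability 1.0f - dropout_prob so that, as the in-code comment says, the C++ path reproduces the same random mask as the Python implementation. A small standalone sketch of that call pattern (the fused_dropout_example wrapper is hypothetical; only the ATen call itself comes from the code above):

#include <ATen/ATen.h>
#include <tuple>

// Sketch only: mirrors the call pattern in fwd_cuda. ATen's _fused_dropout
// takes the keep probability and returns {scaled output, uint8 mask}, the
// same pair fwd_cuda stores as dropout_results / dropout_mask.
void fused_dropout_example(const at::Tensor &softmax_results,
                           float dropout_prob, at::Tensor &dropout_results,
                           at::Tensor &dropout_mask) {
  auto dropout_tuple =
      at::_fused_dropout(softmax_results, 1.0 - dropout_prob);
  dropout_results = std::get<0>(dropout_tuple);
  dropout_mask = std::get<1>(dropout_tuple);
}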
...
@@ -5,87 +5,66 @@ namespace multihead_attn {
namespace self {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x) \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &input_weights,
    torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(
      use_time_mask, is_training, heads, inputs, input_weights, output_weights,
      use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr()) : nullptr,
      dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
@@ -96,29 +75,28 @@ std::vector<torch::Tensor> bwd(
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_results, inputs, input_weights,
                  output_weights, dropout_mask, dropout_prob);
}

} // end namespace rocblas_gemm_ex
@@ -129,4 +107,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::self::rocblas_gemmex::fwd, "Self Multihead Attention Forward.");
  m.def("backward", &multihead_attn::self::rocblas_gemmex::bwd, "Self Multihead Attention Backward.");
}
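The fwd binding above distinguishes a padding mask (use_mask, a 2D byte tensor over batch and key positions) from a causal time mask (use_time_mask); the kernel file that follows routes these to dispatch_masked_softmax and dispatch_time_masked_softmax respectively. A libtorch sketch of building the two byte masks; the polarity (nonzero byte = position masked out) and the mask shapes are assumptions here, since the actual convention lives in softmax.h, which is not part of this diff:

#include <torch/torch.h>

// Sketch under an assumed convention: a nonzero byte marks a position that
// the masked softmax should ignore.
torch::Tensor make_padding_mask(const torch::Tensor &lengths, int64_t seq_len) {
  // lengths: [batch] int64 valid lengths; result: [batch, seq_len] uint8.
  auto range = torch::arange(seq_len, lengths.options()).unsqueeze(0);
  return (range >= lengths.unsqueeze(1)).to(torch::kUInt8);
}

torch::Tensor make_time_mask(int64_t seq_len) {
  // Strictly upper-triangular: future key positions masked for each query.
  auto ones = torch::ones({seq_len, seq_len}, torch::kUInt8);
  return torch::triu(ones, /*diagonal=*/1);
}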
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "layer_norm.h"
#include "softmax.h"
#include "strided_batched_gemm.h"

namespace multihead_attn {
namespace self {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
@@ -55,24 +47,32 @@ std::vector<torch::Tensor> fwd_cuda(
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = inputs.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);
  torch::Tensor input_lin_results =
      torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor outputs = torch::empty_like(inputs, act_options);

  // Input Linear Results Pointers to Q, K, and V of interviewed activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
  void *k_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
@@ -132,40 +132,30 @@ std::vector<torch::Tensor> fwd_cuda(
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
        k_seq_len, attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len,
          attn_batches * q_seq_len / sequences);
    }
  }
  assert(softmax_success);

  if (is_training) {
    apex_fused_dropout_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(softmax_results.data_ptr()),
        static_cast<at::Half *>(dropout_results.data_ptr()),
        static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
        (1.0f - dropout_prob));
  }
@@ -219,30 +209,17 @@ std::vector<torch::Tensor> fwd_cuda(
      flags));
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_lin_results, softmax_results, dropout_results,
          dropout_mask, matmul2_results, outputs};
}

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
@@ -273,13 +250,17 @@ std::vector<torch::Tensor> bwd_cuda(
  at::Tensor matmul2_grads = torch::empty_like(dropout_results);
  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);

  auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
  auto k_lin_results_ptr =
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr =
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr =
      static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr =
      static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
@@ -397,12 +378,10 @@ std::vector<torch::Tensor> bwd_cuda(
  // Softmax Grad
  bool softmax_success = false;
  softmax_success = dispatch_softmax_backward<half, half, float>(
      static_cast<half *>(matmul2_grads.data_ptr()),
      static_cast<half *>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
      k_seq_len, attn_batches * q_seq_len);
  assert(softmax_success);

  // Matmul1 Dgrad1
...
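dispatch_softmax_backward above runs in place (the gradient buffer is both input and output) against the saved softmax outputs. Assuming it computes the standard softmax Jacobian-vector product, dx_i = y_i * (dy_i - sum_j dy_j * y_j), a scalar reference for one row of width k_seq_len looks like this hypothetical helper:

#include <cstddef>
#include <vector>

// Hypothetical scalar reference for a single softmax row: given upstream
// gradients dy and saved softmax outputs y, the input gradient is
// dx_i = y_i * (dy_i - sum_j dy_j * y_j).
std::vector<float> softmax_backward_row(const std::vector<float> &dy,
                                        const std::vector<float> &y) {
  float dot = 0.0f;
  for (size_t j = 0; j < y.size(); ++j) {
    dot += dy[j] * y[j];
  }
  std::vector<float> dx(y.size());
  for (size_t i = 0; i < y.size(); ++i) {
    dx[i] = y[i] * (dy[i] - dot);
  }
  return dx;
}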
@@ -5,111 +5,86 @@ namespace multihead_attn {
namespace self_norm_add {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &lyr_nrm_gamma_weights,
                                    torch::Tensor const &lyr_nrm_beta_weights,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results,
    torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
    torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
    torch::Tensor const &lyr_nrm_gamma_weights,
    torch::Tensor const &lyr_nrm_beta_weights,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
    float dropout_prob);

// C++ interface

#define CHECK_CUDA(x) \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &lyr_nrm_gamma_weights,
    torch::Tensor const &lyr_nrm_beta_weights,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &pad_mask, float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(
      use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights,
      lyr_nrm_beta_weights, input_weights, output_weights,
      use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr()) : nullptr,
      dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results,
    torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
    torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
    torch::Tensor const &lyr_nrm_gamma_weights,
    torch::Tensor const &lyr_nrm_beta_weights,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
    float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
@@ -126,40 +101,42 @@ std::vector<torch::Tensor> bwd(
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float,
             "Only FLOAT is supported");
  AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float,
             "Only FLOAT is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");
  AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_results, lyr_nrm_results,
                  lyr_nrm_mean, lyr_nrm_invvar, inputs, lyr_nrm_gamma_weights,
                  lyr_nrm_beta_weights, input_weights, output_weights,
                  dropout_mask, dropout_add_mask, dropout_prob);
}

} // end namespace cublas_gemmex
@@ -170,4 +147,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::self_norm_add::rocblas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward.");
  m.def("backward", &multihead_attn::self_norm_add::rocblas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward.");
}
#include <vector>
#include <math.h>
#include <iostream> #include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
//#include <cuda_profiler_api.h> //#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h" #include "dropout.h"
#include "layer_norm.h" #include "layer_norm.h"
#include "softmax.h"
// symbol to be automatically resolved by PyTorch libs #include "strided_batched_gemm.h"
extern THCState *state;
namespace multihead_attn { namespace multihead_attn {
namespace self_norm_add { namespace self_norm_add {
namespace rocblas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
bool use_time_mask, int heads, torch::Tensor const &inputs,
bool is_training, torch::Tensor const &lyr_nrm_gamma_weights,
int heads, torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const& inputs, torch::Tensor const &input_weights,
torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const &output_weights,
torch::Tensor const& lyr_nrm_beta_weights, const uint8_t *pad_mask,
torch::Tensor const& input_weights, float dropout_prob) {
torch::Tensor const& output_weights,
const uint8_t* pad_mask,
float dropout_prob
)
{
const int embed_dim = inputs.size(2); const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1); const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0); const int q_seq_len = inputs.size(0);
...@@ -58,7 +50,8 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -58,7 +50,8 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream); cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code) // 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = inputs.options().requires_grad(false); auto act_options = inputs.options().requires_grad(false);
auto lyr_nrm_options = act_options.dtype(torch::kFloat32); auto lyr_nrm_options = act_options.dtype(torch::kFloat32);
auto mask_options = act_options.dtype(torch::kUInt8); auto mask_options = act_options.dtype(torch::kUInt8);
...@@ -67,22 +60,29 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -67,22 +60,29 @@ std::vector<torch::Tensor> fwd_cuda(
  torch::Tensor lyr_nrm_invvar = torch::empty({batches}, lyr_nrm_options);
  torch::Tensor lyr_nrm_results = torch::empty_like(inputs, act_options);
  torch::Tensor input_lin_results =
      torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor output_lin_results = torch::empty_like(inputs, act_options);
  torch::Tensor dropout_add_mask = torch::empty_like(inputs, mask_options);
  torch::Tensor outputs = torch::empty_like(inputs, act_options);

  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
  void *k_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
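  // K and V live in the same interleaved buffer as Q: each pointer is the Q
  // base pointer offset by head_dim (K) or 2 * head_dim (V) half elements.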

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
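  // The 't'/'n' layout characters are presumably the transpose flags handed to
  // the strided-batched GEMM wrappers in the elided hunks below.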
@@ -90,16 +90,15 @@ std::vector<torch::Tensor> fwd_cuda(
  //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Layer Norm
  HostApplyLayerNorm<at::Half, float>(
      static_cast<at::Half *>(lyr_nrm_results.data_ptr()),
      static_cast<float *>(lyr_nrm_mean.data_ptr()),
      static_cast<float *>(lyr_nrm_invvar.data_ptr()),
      static_cast<const at::Half *>(inputs.data_ptr()),
      static_cast<int>(batches),   // n1
      static_cast<int>(embed_dim), // n2
      1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
      static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));

  // Input Linear Fwd
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -155,40 +154,30 @@ std::vector<torch::Tensor> fwd_cuda(
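  // Pre-norm structure: layer norm is applied to `inputs` and its result
  // presumably feeds the QKV projection above, while the original `inputs` are
  // added back in the dropout-add at the end of the block.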
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
        k_seq_len, attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len,
          attn_batches * q_seq_len / sequences);
    }
  }
  assert(softmax_success);

  if (is_training) {
    apex_fused_dropout_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(softmax_results.data_ptr()),
        static_cast<at::Half *>(dropout_results.data_ptr()),
        static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
        (1.0f - dropout_prob));
  }
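  // In training mode the fused dropout kernel writes both dropout_results and
  // the keep mask in dropout_mask; the backward pass takes that mask and
  // presumably rescales by 1 / (1 - dropout_prob), mirroring the dropout-add
  // path below.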
@@ -245,57 +234,38 @@ std::vector<torch::Tensor> fwd_cuda(
  // End-of-block Dropout-Add
  if (is_training) {
    apex_dropout_add_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(output_lin_results.data_ptr()),
        static_cast<at::Half const *>(inputs.data_ptr()),
        static_cast<at::Half *>(outputs.data_ptr()),
        static_cast<uint8_t *>(dropout_add_mask.data_ptr()), total_tokens,
        (1.0f - dropout_prob));
  } else {
    apex_add_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(output_lin_results.data_ptr()),
        static_cast<at::Half const *>(inputs.data_ptr()),
        static_cast<at::Half *>(outputs.data_ptr()), total_tokens);
  }

  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {lyr_nrm_results, lyr_nrm_mean,    lyr_nrm_invvar, input_lin_results,
          softmax_results, dropout_results, dropout_mask,   matmul2_results,
          dropout_add_mask, outputs};
}
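
// fwd_cuda hands back every intermediate that bwd_cuda below needs (layer norm
// statistics, the interleaved QKV activations, softmax/dropout results and the
// two dropout masks), so nothing has to be recomputed in the backward pass.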
std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results,
    torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
    torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
    torch::Tensor const &lyr_nrm_gamma_weights,
    torch::Tensor const &lyr_nrm_beta_weights,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
    float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
@@ -331,13 +301,17 @@ std::vector<torch::Tensor> bwd_cuda(
  torch::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
  torch::Tensor input_lin_grads = torch::empty_like(inputs);

  auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
  auto k_lin_results_ptr =
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr =
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr =
      static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr =
      static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;
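  // The gradient buffer is allocated with the same interleaved QKV layout as
  // the forward activations, so the K/V gradient pointers use the same
  // head_dim / 2 * head_dim offsets.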

  char a_layout_n{'n'};
  char a_layout_t{'t'};
@@ -347,11 +321,10 @@ std::vector<torch::Tensor> bwd_cuda(
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Dropout Add Backward
  apex_masked_scale_cuda<at::Half, float, uint32_t>(
      static_cast<at::Half const *>(output_grads.data_ptr()),
      static_cast<at::Half *>(dropout_add_grads.data_ptr()),
      static_cast<uint8_t const *>(dropout_add_mask.data_ptr()), total_tokens,
      (1.0 / (1.0 - dropout_prob)));
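  // This inverts the forward dropout-add: the saved dropout_add_mask
  // presumably zeroes the dropped positions while the surviving gradients are
  // scaled back up by 1 / (1 - dropout_prob).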

  // Output Linear Dgrad
@@ -463,12 +436,10 @@ std::vector<torch::Tensor> bwd_cuda(
  // Softmax Grad
  bool softmax_success = false;
  softmax_success = dispatch_softmax_backward<half, half, float>(
      static_cast<half *>(matmul2_grads.data_ptr()),
      static_cast<half *>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
      k_seq_len, attn_batches * q_seq_len);
  assert(softmax_success);
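  // The softmax backward runs in place: matmul2_grads is passed as both the
  // gradient input and output buffer, with softmax_results supplying the saved
  // forward probabilities.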

  // Matmul1 Dgrad1
@@ -572,31 +543,23 @@ std::vector<torch::Tensor> bwd_cuda(
                                        flags));

  // Fused Layer Norm Bwd with Residual Add
  HostLayerNormGradient<half, float>(
      static_cast<const half *>(input_lin_grads.data_ptr()),
      static_cast<half const *>(output_grads.data_ptr()),
      static_cast<const float *>(lyr_nrm_mean.data_ptr()),
      static_cast<const float *>(lyr_nrm_invvar.data_ptr()), inputs,
      static_cast<int>(batches),   // n1
      static_cast<int>(embed_dim), // n2
      static_cast<const half *>(lyr_nrm_gamma_weights.data_ptr()),
      static_cast<const half *>(lyr_nrm_beta_weights.data_ptr()), 1.0e-5,
      static_cast<half *>(input_grads.data_ptr()),
      static_cast<half *>(lyr_nrm_gamma_grads.data_ptr()),
      static_cast<half *>(lyr_nrm_beta_grads.data_ptr()));

  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_grads, lyr_nrm_gamma_grads, lyr_nrm_beta_grads,
          input_weight_grads, output_weight_grads};
}

} // end namespace rocblas_gemmex
...