Commit 9615983e authored by Masaki Kozuki, committed by hubertlu-tw

Remove `THCState` from `apex/contrib/multihead_attn` (#1239)

* pass `self.mask_additive`

* clang-format

* removing THCState
parent d11ddccf
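The substance of the change: each extension source drops the legacy `extern THCState *state;` declaration, since streams are already obtained through ATen and the extern was dead code. Below is a minimal sketch of the pattern, not code from this commit; the helper name is hypothetical, and it assumes a cuBLAS handle fetched elsewhere (e.g. via `at::cuda::getCurrentCUDABlasHandle()`, which sits outside the hunks shown here).
// Before this commit (removed): THC global resolved at load time by the PyTorch libs.
// extern THCState *state;
// After: the current stream comes straight from ATen.
#include <ATen/cuda/CUDAContext.h>
#include <cublas_v2.h>
// Hypothetical helper, for illustration only.
void bind_current_stream(cublasHandle_t handle) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);
}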
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
torch::Tensor const &input,
const half *pad_mask, float dropout_prob);
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
torch::Tensor const &input,
torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
}
return fwd_cuda(is_training, heads, input,
use_mask ? static_cast<const half *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
// "Only BYTE is supported");
return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
dropout_prob);
}
} // namespace additive_mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd,
"Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd,
"Self Multihead Attention masked softmax dropout -- Backward.");
}
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "dropout.h"
#include "softmax.h"
namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
torch::Tensor const &input,
const half *pad_mask, float dropout_prob) {
const int attn_batches = input.size(0);
const int sequences = attn_batches / heads;
const int q_seq_len = input.size(1);
@@ -41,63 +35,54 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = input.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void *input_ptr = static_cast<void *>(input.data_ptr());
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
attn_batches * q_seq_len);
} else {
softmax_success = dispatch_additive_masked_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
if (is_training) {
// use at:: function so that C++ version generates the same random mask as
// python version
auto dropout_tuple =
at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
// Matmul2
return {dropout_results, dropout_mask, softmax_results};
}
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask, float dropout_prob) {
const int attn_batches = output_grads.size(0);
const int q_seq_len = output_grads.size(1);
const int k_seq_len = q_seq_len;
@@ -109,23 +94,20 @@ torch::Tensor bwd_cuda(
cublasSetStream(handle, stream);
// Output Tensor Allocations
// torch::Tensor input_grads = torch::empty_like(output_grads);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
static_cast<half *>(output_grads.data_ptr()),
static_cast<half *>(output_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()),
static_cast<uint8_t const *>(dropout_mask.data_ptr()),
1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
attn_batches * q_seq_len, stream);
// backward pass is completely in-place
return output_grads;
}
} // namespace additive_mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
@@ -11,33 +11,22 @@
const int UNROLL = 4;
template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void
apex_fused_dropout_kernel(scalar_t const *inputs, scalar_t *outputs,
uint8_t *mask, IndexType totalElements, accscalar_t p,
std::pair<uint64_t, uint64_t> seeds) {
accscalar_t pinv = accscalar_t(1) / p;
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seeds.first, idx, seeds.second, &state);
IndexType rounded_size =
((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx; linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x * UNROLL) {
float4 rand = curand_uniform4(&state);
scalar_t src[UNROLL];
rand.x = rand.x <= p;
@@ -54,7 +43,7 @@ __global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
outputs[li] = src[ii] * (&rand.x)[ii] * pinv;
mask[li] = (uint8_t)(&rand.x)[ii];
}
}
@@ -62,34 +51,23 @@ __global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
}
}
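The kernel above strides the grid by blockDim.x * gridDim.x * UNROLL and pre-rounds totalElements up to a multiple of that stride; the per-element if (li < totalElements) guard discards the overshoot. A small host-side check of that rounding, using illustrative launch sizes that are assumptions rather than values taken from this diff:
#include <cstdint>
// Same rounding used for rounded_size in the kernels: smallest multiple of
// (blockDim.x * gridDim.x * UNROLL) that is >= totalElements.
constexpr uint32_t round_up_to_stride(uint32_t total, uint32_t block,
                                      uint32_t grid, uint32_t unroll) {
  return ((total - 1) / (block * grid * unroll) + 1) * block * grid * unroll;
}
int main() {
  // Hypothetical launch: 256 threads/block, 10 blocks, UNROLL = 4 -> stride 10240.
  static_assert(round_up_to_stride(10000, 256, 10, 4) == 10240, "rounds up");
  static_assert(round_up_to_stride(10240, 256, 10, 4) == 10240, "already aligned");
  return 0;
}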
template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_dropout_add_kernel(scalar_t const *inputs,
scalar_t const *add_inputs,
scalar_t *outputs, uint8_t *mask,
IndexType totalElements, accscalar_t p,
std::pair<uint64_t, uint64_t> seeds) {
accscalar_t pinv = accscalar_t(1) / p;
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seeds.first, idx, seeds.second, &state);
IndexType rounded_size =
((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx; linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x * UNROLL) {
float4 rand = curand_uniform4(&state);
scalar_t src[UNROLL];
scalar_t add_src[UNROLL];
@@ -108,7 +86,8 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs,
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
accscalar_t int1 = src[ii] * (&rand.x)[ii] * pinv;
outputs[li] =
static_cast<scalar_t>(static_cast<accscalar_t>(add_src[ii]) + int1);
mask[li] = (uint8_t)(&rand.x)[ii];
}
}
@@ -116,22 +95,16 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs,
}
}
template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_add_kernel(scalar_t const *inputs,
scalar_t const *add_inputs, scalar_t *outputs,
IndexType totalElements) {
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
IndexType rounded_size =
((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx; linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x * UNROLL) {
scalar_t src[UNROLL];
scalar_t add_src[UNROLL];
for (int ii = 0; ii < UNROLL; ii++) {
@@ -151,23 +124,17 @@ __global__ void apex_add_kernel( scalar_t const *inputs,
}
}
template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_masked_scale_kernel(scalar_t const *inputs,
scalar_t *outputs, uint8_t const *mask,
IndexType totalElements,
accscalar_t scale) {
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
IndexType rounded_size =
((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx; linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x * UNROLL) {
scalar_t src[UNROLL];
scalar_t msk[UNROLL];
for (int ii = 0; ii < UNROLL; ii++) {
@@ -180,33 +147,34 @@ __global__ void apex_masked_scale_kernel(scalar_t const *inputs,
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
outputs[li] = static_cast<accscalar_t>(src[ii]) * scale *
static_cast<accscalar_t>(msk[ii]);
}
}
}
}
template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_fused_dropout_cuda(scalar_t const *inputs, scalar_t *outputs,
uint8_t *mask, IndexType totalElements,
accscalar_t p) {
auto gen = at::cuda::detail::getDefaultCUDAGenerator();
int block_size = 256;
dim3 dim_block(block_size);
dim3 grid((totalElements + block_size - 1) / block_size);
unsigned int blocks_per_sm =
at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
->multiProcessorCount *
blocks_per_sm,
grid.x);
// number of times random will be generated per thread, to offset philox
// counter in the random state
int64_t counter_offset =
((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
@@ -215,36 +183,39 @@ void apex_fused_dropout_cuda(scalar_t const *inputs,
rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
#else
std::lock_guard<std::mutex> lock(gen.mutex());
rng_engine_inputs =
at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
counter_offset);
#endif
}
apex_fused_dropout_kernel<scalar_t, accscalar_t, IndexType>
<<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
inputs, outputs, mask, totalElements, p, rng_engine_inputs);
C10_CUDA_CHECK(cudaGetLastError());
}
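counter_offset above is the Philox counter advance reserved per launch: the number of grid-stride iterations each thread performs, times UNROLL, so the next launch starts on fresh random numbers. A host-side sketch of the same arithmetic, with an assumed 2^20-element tensor and the SM-count clamp on grid.x ignored:
#include <cstdint>
#include <iostream>
// Mirrors the counter_offset computation in apex_fused_dropout_cuda:
// ceil(totalElements / (block * grid * UNROLL)) iterations per thread, times UNROLL.
int64_t philox_counter_offset(int64_t total, int64_t block, int64_t grid,
                              int64_t unroll) {
  return ((total - 1) / (block * grid * unroll) + 1) * unroll;
}
int main() {
  const int64_t block = 256, unroll = 4;
  const int64_t total = 1 << 20;                     // assumed element count
  const int64_t grid = (total + block - 1) / block;  // grid before the SM clamp
  std::cout << philox_counter_offset(total, block, grid, unroll) << "\n";  // prints 4
  return 0;
}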
template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_dropout_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs,
scalar_t *outputs, uint8_t *mask,
IndexType totalElements, accscalar_t p) {
auto gen = at::cuda::detail::getDefaultCUDAGenerator();
int block_size = 256;
dim3 dim_block(block_size);
dim3 grid((totalElements + block_size - 1) / block_size);
unsigned int blocks_per_sm =
at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
->multiProcessorCount *
blocks_per_sm,
grid.x);
// number of times random will be generated per thread, to offset philox
// counter in the random state
int64_t counter_offset =
((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
@@ -253,54 +224,56 @@ void apex_dropout_add_cuda(scalar_t const *inputs,
rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
#else
std::lock_guard<std::mutex> lock(gen.mutex());
rng_engine_inputs =
at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
counter_offset);
#endif
}
apex_dropout_add_kernel<scalar_t, accscalar_t, IndexType>
<<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
inputs, add_inputs, outputs, mask, totalElements, p,
rng_engine_inputs);
C10_CUDA_CHECK(cudaGetLastError());
}
template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs,
scalar_t *outputs, IndexType totalElements) {
int block_size = 256;
dim3 dim_block(block_size);
dim3 grid((totalElements + block_size - 1) / block_size);
unsigned int blocks_per_sm =
at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
->multiProcessorCount *
blocks_per_sm,
grid.x);
apex_add_kernel<scalar_t, accscalar_t, IndexType>
<<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
inputs, add_inputs, outputs, totalElements);
C10_CUDA_CHECK(cudaGetLastError());
}
template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_masked_scale_cuda(scalar_t const *inputs, scalar_t *outputs,
uint8_t const *mask, IndexType totalElements,
accscalar_t scale) {
int block_size = 256;
dim3 dim_block(block_size);
dim3 grid((totalElements + block_size - 1) / block_size);
unsigned int blocks_per_sm =
at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
->multiProcessorCount *
blocks_per_sm,
grid.x);
apex_masked_scale_kernel<scalar_t, accscalar_t, IndexType>
<<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
inputs, outputs, mask, totalElements, scale);
C10_CUDA_CHECK(cudaGetLastError());
}
@@ -5,103 +5,79 @@ namespace multihead_attn {
namespace encdec {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv,
torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
input_weights_q, input_weights_kv, output_weights,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
@@ -115,35 +91,35 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_q_results, input_lin_kv_results,
inputs_q, inputs_kv, input_weights_q, input_weights_kv,
output_weights, dropout_mask, dropout_prob);
}
} // namespace rocblas_gemmex
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "dropout.h"
#include "layer_norm.h"
#include "softmax.h"
#include "strided_batched_gemm.h"
namespace multihead_attn {
namespace encdec {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv,
torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob) {
const int embed_dim = inputs_q.size(2);
const int sequences = inputs_q.size(1);
const int q_seq_len = inputs_q.size(0);
@@ -48,7 +39,7 @@ std::vector<torch::Tensor> fwd_cuda(
const int output_lin_kv_dim = 2 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim_q = attn_batches * head_dim;
const int lead_dim_kv = attn_batches * 2 * head_dim;
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
@@ -62,25 +53,34 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = inputs_q.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor input_lin_q_results =
torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options);
torch::Tensor input_lin_kv_results =
torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor outputs = torch::empty_like(inputs_q, act_options);
// Input Linear Results Pointers to Q, K, and V of interviewed activations
void *q_lin_results_ptr = static_cast<void *>(input_lin_q_results.data_ptr());
void *k_lin_results_ptr =
static_cast<void *>(input_lin_kv_results.data_ptr());
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
@@ -166,40 +166,30 @@ std::vector<torch::Tensor> fwd_cuda(
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
} else {
if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
}
assert(softmax_success);
if (is_training) {
apex_fused_dropout_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(softmax_results.data_ptr()),
static_cast<at::Half *>(dropout_results.data_ptr()),
static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
(1.0f - dropout_prob));
}
@@ -253,34 +243,24 @@ std::vector<torch::Tensor> fwd_cuda(
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_q_results,
input_lin_kv_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs};
}
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
const int embed_dim = inputs_q.size(2);
const int sequences = inputs_q.size(1);
const int q_seq_len = inputs_q.size(0);
@@ -292,7 +272,7 @@ std::vector<torch::Tensor> bwd_cuda(
const int output_lin_kv_dim = 2 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim_q = attn_batches * head_dim;
const int lead_dim_kv = attn_batches * 2 * head_dim;
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
@@ -316,15 +296,20 @@ std::vector<torch::Tensor> bwd_cuda(
at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results);
at::Tensor input_lin_kv_output_grads =
torch::empty_like(input_lin_kv_results);
auto q_lin_results_ptr = static_cast<half *>(input_lin_q_results.data_ptr());
auto k_lin_results_ptr = static_cast<half *>(input_lin_kv_results.data_ptr());
auto v_lin_results_ptr =
static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim;
auto q_lin_grads_ptr =
static_cast<half *>(input_lin_q_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_kv_output_grads.data_ptr());
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_kv_output_grads.data_ptr()) + head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
@@ -442,12 +427,10 @@ std::vector<torch::Tensor> bwd_cuda(
// Softmax Grad
bool softmax_success = false;
softmax_success = dispatch_softmax_backward<half, half, float>(
static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
assert(softmax_success);
// Matmul1 Dgrad1
@@ -5,66 +5,49 @@ namespace multihead_attn {
namespace encdec_norm_add {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
torch::Tensor const &dropout_add_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
@@ -73,58 +56,48 @@ std::vector<torch::Tensor> fwd(
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q,
input_weights_kv, output_weights,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
torch::Tensor const &dropout_add_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
@@ -144,47 +117,49 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float,
"Only FLOAT is supported");
AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float,
"Only FLOAT is supported");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_q_results, input_lin_kv_results,
lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs_q,
inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights,
input_weights_q, input_weights_kv, output_weights,
dropout_mask, dropout_add_mask, dropout_prob);
}
} // namespace rocblas_gemmex
@@ -195,4 +170,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::encdec_norm_add::rocblas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def("backward", &multihead_attn::encdec_norm_add::rocblas_gemmex::bwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.");
}
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "dropout.h"
#include "layer_norm.h"
#include "softmax.h"
#include "strided_batched_gemm.h"
namespace multihead_attn {
namespace encdec_norm_add {
@@ -64,7 +61,8 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = inputs_q.options().requires_grad(false);
auto lyr_nrm_options = act_options.dtype(torch::kFloat32);
auto mask_options = act_options.dtype(torch::kUInt8);
@@ -73,23 +71,31 @@ std::vector<torch::Tensor> fwd_cuda(
torch::Tensor lyr_nrm_invvar = torch::empty({batches_q}, lyr_nrm_options);
torch::Tensor lyr_nrm_results = torch::empty_like(inputs_q, act_options);
torch::Tensor input_lin_q_results = torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options);
torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor input_lin_q_results =
torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options);
torch::Tensor input_lin_kv_results =
torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor output_lin_results = torch::empty_like(inputs_q, act_options);
torch::Tensor dropout_add_mask = torch::empty_like(inputs_q, mask_options);
torch::Tensor outputs = torch::empty_like(inputs_q, act_options);
// Input Linear Results Pointers to Q, K, and V of interleaved activations
void* q_lin_results_ptr = static_cast<void*>(input_lin_q_results.data_ptr());
void* k_lin_results_ptr = static_cast<void*>(input_lin_kv_results.data_ptr());
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim);
void *q_lin_results_ptr = static_cast<void *>(input_lin_q_results.data_ptr());
void *k_lin_results_ptr =
static_cast<void *>(input_lin_kv_results.data_ptr());
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim);
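// Layout note: input_lin_kv_results interleaves K and V per head in blocks of
// 2 * head_dim channels, so the V pointer is simply the K pointer advanced by
// head_dim; the strided-batched GEMMs below pick out each attention batch with
// a 2 * head_dim batch stride.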
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
@@ -97,16 +103,15 @@ std::vector<torch::Tensor> fwd_cuda(
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm
HostApplyLayerNorm<at::Half,float>(
static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
static_cast<float*>(lyr_nrm_mean.data_ptr()),
static_cast<float*>(lyr_nrm_invvar.data_ptr()),
static_cast<const at::Half*>(inputs_q.data_ptr()),
HostApplyLayerNorm<at::Half, float>(
static_cast<at::Half *>(lyr_nrm_results.data_ptr()),
static_cast<float *>(lyr_nrm_mean.data_ptr()),
static_cast<float *>(lyr_nrm_invvar.data_ptr()),
static_cast<const at::Half *>(inputs_q.data_ptr()),
static_cast<int>(batches_q), // n1
static_cast<int>(embed_dim), // n2
1.0e-5,
static_cast<const at::Half*>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
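// HostApplyLayerNorm normalizes each of the batches_q rows over its embed_dim
// channels with eps = 1e-5; the fp32 mean and inverse-variance tensors it
// fills are part of the return value so the backward pass can reuse them.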
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -187,40 +192,30 @@ std::vector<torch::Tensor> fwd_cuda(
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
} else {
if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
q_seq_len);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
}
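// Dispatch summary: every path treats the attention scores as
// attn_batches * q_seq_len rows of length k_seq_len. Without a mask this is a
// plain softmax; with use_time_mask the per-position time mask is applied;
// otherwise each of the `sequences` padding-mask rows is shared by the
// heads * q_seq_len rows of its sequence (attn_batches * q_seq_len / sequences).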
assert(softmax_success);
if (is_training) {
apex_fused_dropout_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(softmax_results.data_ptr()),
static_cast<at::Half*>(dropout_results.data_ptr()),
static_cast<uint8_t*>(dropout_mask.data_ptr()),
dropout_elems,
apex_fused_dropout_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(softmax_results.data_ptr()),
static_cast<at::Half *>(dropout_results.data_ptr()),
static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
(1.0f - dropout_prob));
}
@@ -276,25 +271,22 @@ std::vector<torch::Tensor> fwd_cuda(
// End-of-block Dropout-Add
if (is_training) {
apex_dropout_add_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_lin_results.data_ptr()),
static_cast<at::Half const*>(inputs_q.data_ptr()),
static_cast<at::Half*>(outputs.data_ptr()),
static_cast<uint8_t*>(dropout_add_mask.data_ptr()),
total_tokens_q,
apex_dropout_add_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(output_lin_results.data_ptr()),
static_cast<at::Half const *>(inputs_q.data_ptr()),
static_cast<at::Half *>(outputs.data_ptr()),
static_cast<uint8_t *>(dropout_add_mask.data_ptr()), total_tokens_q,
(1.0f - dropout_prob));
} else {
apex_add_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_lin_results.data_ptr()),
static_cast<at::Half const*>(inputs_q.data_ptr()),
static_cast<at::Half*>(outputs.data_ptr()),
total_tokens_q);
apex_add_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(output_lin_results.data_ptr()),
static_cast<at::Half const *>(inputs_q.data_ptr()),
static_cast<at::Half *>(outputs.data_ptr()), total_tokens_q);
}
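// In training the residual connection is fused with dropout on the output
// projection (keep probability 1 - dropout_prob) and the byte mask is kept for
// the backward pass; in evaluation only the plain residual add runs.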
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
lyr_nrm_results,
return {lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
input_lin_q_results,
@@ -304,33 +296,22 @@ std::vector<torch::Tensor> fwd_cuda(
dropout_mask,
matmul2_results,
dropout_add_mask,
outputs
};
outputs};
}
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_q_results,
torch::Tensor const& input_lin_kv_results,
torch::Tensor const& lyr_nrm_results,
torch::Tensor const& lyr_nrm_mean,
torch::Tensor const& lyr_nrm_invvar,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
torch::Tensor const& dropout_add_mask,
float dropout_prob
)
{
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
torch::Tensor const &dropout_add_mask, float dropout_prob) {
const int embed_dim = inputs_q.size(2);
const int sequences = inputs_q.size(1);
const int q_seq_len = inputs_q.size(0);
@@ -343,7 +324,7 @@ std::vector<torch::Tensor> bwd_cuda(
const int output_lin_kv_dim = 2 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim_q = attn_batches * head_dim;
const int lead_dim_kv = attn_batches * 2 *head_dim;
const int lead_dim_kv = attn_batches * 2 * head_dim;
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
@@ -370,16 +351,21 @@ std::vector<torch::Tensor> bwd_cuda(
at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results);
at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results);
at::Tensor input_lin_kv_output_grads =
torch::empty_like(input_lin_kv_results);
at::Tensor input_lin_q_grads = torch::empty_like(inputs_q);
auto q_lin_results_ptr = static_cast<half*>(input_lin_q_results.data_ptr());
auto k_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr());
auto v_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim;
auto q_lin_results_ptr = static_cast<half *>(input_lin_q_results.data_ptr());
auto k_lin_results_ptr = static_cast<half *>(input_lin_kv_results.data_ptr());
auto v_lin_results_ptr =
static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim;
auto q_lin_grads_ptr = static_cast<half*>(input_lin_q_output_grads.data_ptr());
auto k_lin_grads_ptr = static_cast<half*>(input_lin_kv_output_grads.data_ptr());
auto v_lin_grads_ptr = static_cast<half*>(input_lin_kv_output_grads.data_ptr()) + head_dim;
auto q_lin_grads_ptr =
static_cast<half *>(input_lin_q_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_kv_output_grads.data_ptr());
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_kv_output_grads.data_ptr()) + head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
@@ -505,12 +491,10 @@ std::vector<torch::Tensor> bwd_cuda(
// Softmax Grad
bool softmax_success = false;
softmax_success = dispatch_softmax_backward<half, half, float>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
assert(softmax_success);
// Matmul1 Dgrad1
@@ -683,15 +667,9 @@ std::vector<torch::Tensor> bwd_cuda(
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_q_grads,
input_kv_grads,
lyr_nrm_gamma_grads,
lyr_nrm_beta_grads,
input_weight_q_grads,
input_weight_kv_grads,
output_weight_grads
};
return {input_q_grads, input_kv_grads, lyr_nrm_gamma_grads,
lyr_nrm_beta_grads, input_weight_q_grads, input_weight_kv_grads,
output_weight_grads};
}
} // end namespace rocblas_gemmex
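// Illustrative sketch (not from the patch above; toy sizes): the flat indexing
// implied by the lead_dim_kv / batch_stride_kv values that the strided-batched
// GEMMs use for the packed KV buffer. The layout
// {k_seq_len, sequences, heads * 2 * head_dim} is an assumption read off the
// surrounding kernel code, not an authoritative spec.
#include <cstdio>

int main() {
  const int heads = 2, sequences = 3, head_dim = 4;
  const int attn_batches = heads * sequences;
  const int lead_dim_kv = attn_batches * 2 * head_dim; // stride between tokens
  const int batch_stride_kv = 2 * head_dim;            // stride between attention batches

  // Offset of K[token t, attention batch b, channel d] inside the packed
  // input_lin_kv_results buffer; the matching V element sits head_dim further.
  auto k_offset = [&](int t, int b, int d) {
    return t * lead_dim_kv + b * batch_stride_kv + d;
  };
  auto v_offset = [&](int t, int b, int d) { return k_offset(t, b, d) + head_dim; };

  std::printf("K(0,0,0)=%d  V(0,0,0)=%d\n", k_offset(0, 0, 0), v_offset(0, 0, 0));
  std::printf("K(1,2,3)=%d  V(1,2,3)=%d\n", k_offset(1, 2, 3), v_offset(1, 2, 3));
  return 0;
}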
@@ -5,81 +5,66 @@ namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(
bool is_training,
int heads,
torch::Tensor const& input,
const uint8_t* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
torch::Tensor const &input,
const uint8_t *pad_mask,
float dropout_prob);
torch::Tensor bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
const uint8_t *padding_mask,
float dropout_prob
);
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask,
const uint8_t *padding_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool is_training,
int heads,
torch::Tensor const& input,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
torch::Tensor const &input,
torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(
is_training,
heads,
input,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
return fwd_cuda(is_training, heads, input,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
torch::Tensor bwd(
bool use_mask,
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
torch::Tensor const& padding_mask,
float dropout_prob
)
{
torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask,
torch::Tensor const &padding_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
// "Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
softmax_results,
dropout_mask,
use_mask ? static_cast<const uint8_t*>(padding_mask.data_ptr()) : nullptr,
dropout_prob
);
return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
use_mask
? static_cast<const uint8_t *>(padding_mask.data_ptr())
: nullptr,
dropout_prob);
}
} // end namespace mask_softmax_dropout
@@ -87,7 +72,8 @@ torch::Tensor bwd(
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd, "Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd, "Self Multihead Attention masked softmax dropout -- Backward.");
m.def("forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd,
"Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd,
"Self Multihead Attention masked softmax dropout -- Backward.");
}
#include <vector>
#include <math.h>
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "softmax.h"
#include "dropout.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
#include "softmax.h"
namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(
bool is_training,
int heads,
torch::Tensor const& input,
const uint8_t* pad_mask,
float dropout_prob
)
{
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
torch::Tensor const &input,
const uint8_t *pad_mask,
float dropout_prob) {
const int attn_batches = input.size(0);
const int sequences = attn_batches / heads;
const int q_seq_len = input.size(1);
@@ -41,64 +34,55 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = input.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
void *input_ptr = static_cast<void *>(input.data_ptr());
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
attn_batches * q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
if (is_training) {
//use at:: function so that C++ version generates the same random mask as python version
auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);
// use at:: function so that C++ version generates the same random mask as
// python version
auto dropout_tuple =
at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
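// at::_fused_dropout takes the keep probability, so it returns the already
// rescaled activations together with a byte mask; dropout_results and
// dropout_mask are simply replaced with ATen's outputs here.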
// Matmul2
return {
dropout_results,
dropout_mask,
softmax_results
};
return {dropout_results, dropout_mask, softmax_results};
}
torch::Tensor bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
const uint8_t *padding_mask,
float dropout_prob
)
{
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask,
const uint8_t *padding_mask, float dropout_prob) {
const int attn_batches = output_grads.size(0);
const int q_seq_len = output_grads.size(1);
const int k_seq_len = q_seq_len;
@@ -110,38 +94,31 @@ torch::Tensor bwd_cuda(
cublasSetStream(handle, stream);
// Output Tensor Allocations
// torch::Tensor input_grads = torch::empty_like(output_grads);
// torch::Tensor input_grads = torch::empty_like(output_grads);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
if (padding_mask == nullptr) {
dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len, stream);
} else{
dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
static_cast<uint8_t const*>(padding_mask),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
heads, stream);
dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
static_cast<half *>(output_grads.data_ptr()),
static_cast<half *>(output_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()),
static_cast<uint8_t const *>(dropout_mask.data_ptr()),
1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
attn_batches * q_seq_len, stream);
} else {
dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float,
false>(
static_cast<half *>(output_grads.data_ptr()),
static_cast<half *>(output_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()),
static_cast<uint8_t const *>(dropout_mask.data_ptr()),
static_cast<uint8_t const *>(padding_mask), 1.0 / (1.0 - dropout_prob),
k_seq_len, k_seq_len, attn_batches * q_seq_len, heads, stream);
}
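// Both branches fuse the dropout-mask application, the 1 / (1 - dropout_prob)
// rescale, and the softmax backward into one in-place kernel; the masked_out
// variant additionally takes the padding mask, presumably so gradients at
// padded positions are zeroed as well.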
//backward pass is completely in-place
// backward pass is completely in-place
return output_grads;
}
}
}
}
} // namespace mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
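// Illustrative sketch (not from the patch above; made-up mask and values): the
// inverted-dropout convention the kernels above rely on -- the forward keeps an
// element with probability (1 - dropout_prob) and the backward multiplies kept
// gradients by 1 / (1 - dropout_prob), leaving the expectation unchanged.
#include <cstdio>

int main() {
  const float dropout_prob = 0.25f;
  const float scale = 1.0f / (1.0f - dropout_prob); // the factor used in bwd_cuda above
  const float grads[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  const unsigned char mask[4] = {1, 0, 1, 1}; // 1 = kept, 0 = dropped

  for (int i = 0; i < 4; ++i) {
    const float g = grads[i] * static_cast<float>(mask[i]) * scale;
    std::printf("grad %.1f -> %.4f\n", grads[i], g);
  }
  return 0;
}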
#pragma once
//Philox CUDA.
// Philox CUDA.
class Philox {
public:
@@ -15,28 +15,30 @@ public:
incr_n(offset / 4);
}
__device__ inline uint4 operator()() {
if(STATE == 0) {
if (STATE == 0) {
uint4 counter_ = counter;
uint2 key_ = key;
//7-round philox
for(int i = 0; i < 6; i++) {
// 7-round philox
for (int i = 0; i < 6; i++) {
counter_ = single_round(counter_, key_);
key_.x += (kPhilox10A); key_.y += (kPhilox10B);
key_.x += (kPhilox10A);
key_.y += (kPhilox10B);
}
output = single_round(counter_, key_);
incr();
}
//return a float4 directly
//unsigned long ret;
//switch(STATE) {
// return a float4 directly
// unsigned long ret;
// switch(STATE) {
// case 0: ret = output.x; break;
// case 1: ret = output.y; break;
// case 2: ret = output.z; break;
// case 3: ret = output.w; break;
//}
//STATE = (STATE + 1) % 4;
// STATE = (STATE + 1) % 4;
return output;
}
private:
uint4 counter;
uint4 output;
@@ -67,7 +69,7 @@ private:
__device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
unsigned int *result_high) {
*result_high = __umulhi(a, b);
return a*b;
return a * b;
}
__device__ inline uint4 single_round(uint4 ctr, uint2 key) {
unsigned int hi0;
@@ -85,6 +87,6 @@ private:
// Inverse of 2^32.
#define M_RAN_INVM32 2.3283064e-10f
__device__ __inline__ float4 uniform4(uint4 x) {
return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,x.w * M_RAN_INVM32);
return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,
x.w * M_RAN_INVM32);
}
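// Illustrative sketch (not from the patch above; sample words are arbitrary):
// what M_RAN_INVM32 does inside uniform4() -- multiplying a raw 32-bit Philox
// word by ~2^-32 maps it onto the unit interval. Host floats stand in for the
// device float4 here; note the top of the range rounds to 1.0f in single
// precision.
#include <cstdint>
#include <cstdio>

int main() {
  const float kInvM32 = 2.3283064e-10f; // same constant as M_RAN_INVM32
  const uint32_t samples[4] = {0u, 0x40000000u, 0x80000000u, 0xffffffffu};
  for (uint32_t s : samples) {
    std::printf("0x%08x -> %.8f\n", static_cast<unsigned>(s), s * kInvM32);
  }
  return 0;
}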
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
#include <cuda_fp16.h>
namespace multihead_attn {
namespace self_bias_additive_mask {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases,
torch::Tensor const& output_biases,
const half* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs,
torch::Tensor const &input_weights,
torch::Tensor const &output_weights,
torch::Tensor const &input_biases,
torch::Tensor const &output_biases,
const half *pad_mask, float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
// torch::Tensor const& softmax_results,
torch::Tensor const& bmm1_results,
torch::Tensor const& pad_mask,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
//torch::Tensor const& input_biases,
//torch::Tensor const& output_biases,
torch::Tensor const& dropout_mask,
float dropout_prob
);
torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
// torch::Tensor const& input_biases,
// torch::Tensor const& output_biases,
torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs, torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases, torch::Tensor const& output_biases,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs, torch::Tensor const &input_weights,
torch::Tensor const &output_weights, torch::Tensor const &input_biases,
torch::Tensor const &output_biases, torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(use_mask , "no mask is not supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(use_mask, "no mask is not supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half, "Only Half is supported");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
"Only Half is supported");
}
return fwd_cuda(
use_time_mask,
is_training,
heads,
inputs,
input_weights,
output_weights,
input_biases,
output_biases,
use_mask ? static_cast<const half*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights,
output_weights, input_biases, output_biases,
use_mask ? static_cast<const half *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
std::vector<torch::Tensor> bwd(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& bmm1_results,
torch::Tensor const& pad_mask,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
@@ -107,29 +82,26 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
matmul2_results,
dropout_results,
bmm1_results,
pad_mask,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob
);
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
bmm1_results, pad_mask, input_lin_results, inputs,
input_weights, output_weights, dropout_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
@@ -140,4 +112,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
#include <vector>
#include <math.h>
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h"
#include "layer_norm.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
#include "softmax.h"
#include "strided_batched_gemm.h"
namespace multihead_attn {
namespace self_bias_additive_mask {
@@ -58,25 +55,33 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = inputs.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor bmm1_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor input_lin_results =
torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor bmm1_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options);
// Input Linear Results Pointers to Q, K, and V of interleaved activations
void* q_lin_results_ptr = static_cast<void*>(input_lin_results.data_ptr());
void* k_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);
void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
void *k_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
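// Layout note: the fused QKV projection packs 3 * head_dim channels per head,
// so K and V are the Q pointer advanced by head_dim and 2 * head_dim
// respectively.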
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* bmm1_results_ptr = static_cast<void*>(bmm1_results.data_ptr());
void* dropout_results_ptr = static_cast<void*>(dropout_results.data_ptr());
void *bmm1_results_ptr = static_cast<void *>(bmm1_results.data_ptr());
void *dropout_results_ptr = static_cast<void *>(dropout_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
@@ -136,27 +141,24 @@ std::vector<torch::Tensor> fwd_cuda(
// Padded Softmax
bool softmax_success = false;
if (is_training) {
softmax_success = dispatch_additive_masked_softmax_dropout<half, half, float>(
reinterpret_cast<half*>(dropout_results_ptr),
(is_training) ? reinterpret_cast<uint8_t*>(dropout_mask.data_ptr<uint8_t>()) : nullptr,
reinterpret_cast<const half*>(bmm1_results_ptr),
pad_mask,
attn_batches*q_seq_len*q_seq_len,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences,
1.0f-dropout_prob,
stream);
softmax_success =
dispatch_additive_masked_softmax_dropout<half, half, float>(
reinterpret_cast<half *>(dropout_results_ptr),
(is_training)
? reinterpret_cast<uint8_t *>(dropout_mask.data_ptr<uint8_t>())
: nullptr,
reinterpret_cast<const half *>(bmm1_results_ptr), pad_mask,
attn_batches * q_seq_len * q_seq_len, k_seq_len, k_seq_len,
attn_batches * q_seq_len, attn_batches * q_seq_len / sequences,
1.0f - dropout_prob, stream);
} else {
softmax_success = dispatch_additive_masked_softmax<half, half, float>(
reinterpret_cast<half*>(dropout_results_ptr),//this is actually softmax results, but making it consistent for the next function
reinterpret_cast<const half*>(bmm1_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
reinterpret_cast<half *>(
dropout_results_ptr), // this is actually softmax results, but
// making it consistent for the next function
reinterpret_cast<const half *>(bmm1_results_ptr), pad_mask, k_seq_len,
k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
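// In training one fused kernel applies the additive mask, the softmax, and
// dropout, filling dropout_results and dropout_mask in a single pass; in
// evaluation only the additively masked softmax runs, so dropout_results holds
// the plain softmax output (as noted above).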
// Matmul2
@@ -211,31 +213,17 @@ std::vector<torch::Tensor> fwd_cuda(
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_lin_results,
bmm1_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs
};
return {input_lin_results, bmm1_results, dropout_results,
dropout_mask, matmul2_results, outputs};
}
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& bmm1_results,
torch::Tensor const& pad_mask,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
@@ -266,13 +254,17 @@ std::vector<torch::Tensor> bwd_cuda(
at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());
auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim;
auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
auto k_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;
auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim;
auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
@@ -496,13 +488,8 @@ std::vector<torch::Tensor> bwd_cuda(
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_grads,
input_weight_grads,
output_weight_grads,
input_bias_grads,
output_bias_grads
};
return {input_grads, input_weight_grads, output_weight_grads,
input_bias_grads, output_bias_grads};
}
} // end namespace rocblas_gemmex
@@ -5,94 +5,70 @@ namespace multihead_attn {
namespace self_bias {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases,
torch::Tensor const& output_biases,
const uint8_t* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor>
fwd_cuda(bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs, torch::Tensor const &input_weights,
torch::Tensor const &output_weights, torch::Tensor const &input_biases,
torch::Tensor const &output_biases, const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
//torch::Tensor const& input_biases,
//torch::Tensor const& output_biases,
torch::Tensor const& dropout_mask,
float dropout_prob
);
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
// torch::Tensor const& input_biases,
// torch::Tensor const& output_biases,
torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs, torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases, torch::Tensor const& output_biases,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs, torch::Tensor const &input_weights,
torch::Tensor const &output_weights, torch::Tensor const &input_biases,
torch::Tensor const &output_biases, torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(
use_time_mask,
is_training,
heads,
inputs,
input_weights,
output_weights,
input_biases,
output_biases,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights,
output_weights, input_biases, output_biases,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
std::vector<torch::Tensor> bwd(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
@@ -103,29 +79,28 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob
);
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_results, inputs, input_weights,
output_weights, dropout_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
@@ -136,4 +111,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
#include <vector>
#include <math.h>
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h"
#include "layer_norm.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
#include "softmax.h"
#include "strided_batched_gemm.h"
namespace multihead_attn {
namespace self_bias {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases,
torch::Tensor const& output_biases,
const uint8_t* pad_mask,
float dropout_prob
)
{
std::vector<torch::Tensor>
fwd_cuda(bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs, torch::Tensor const &input_weights,
torch::Tensor const &output_weights, torch::Tensor const &input_biases,
torch::Tensor const &output_biases, const uint8_t *pad_mask,
float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
@@ -58,24 +48,32 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = inputs.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor input_lin_results =
torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options);
// Input Linear Results Pointers to Q, K, and V of interleaved activations
void* q_lin_results_ptr = static_cast<void*>(input_lin_results.data_ptr());
void* k_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);
void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
void *k_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
@@ -136,37 +134,29 @@ std::vector<torch::Tensor> fwd_cuda(
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
} else {
if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
q_seq_len);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
}
if (is_training) {
//use at:: function so that C++ version generates the same random mask as python version
auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);
// use at:: function so that C++ version generates the same random mask as
// python version
auto dropout_tuple =
at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
@@ -223,30 +213,17 @@ std::vector<torch::Tensor> fwd_cuda(
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_lin_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs
};
return {input_lin_results, softmax_results, dropout_results,
dropout_mask, matmul2_results, outputs};
}
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
@@ -277,13 +254,17 @@ std::vector<torch::Tensor> bwd_cuda(
at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());
auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim;
auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
auto k_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;
auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim;
auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
@@ -393,15 +374,13 @@ std::vector<torch::Tensor> bwd_cuda(
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len, stream);
dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()),
static_cast<uint8_t const *>(dropout_mask.data_ptr()),
1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
attn_batches * q_seq_len, stream);
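// This fused backward applies the saved dropout mask, rescales by
// 1 / (1 - dropout_prob), and computes the softmax gradient in place in
// matmul2_grads.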
// Matmul1 Dgrad1
gemm_switch_fp32accum( state,
@@ -503,13 +482,8 @@ std::vector<torch::Tensor> bwd_cuda(
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_grads,
input_weight_grads,
output_weight_grads,
input_bias_grads,
output_bias_grads
};
return {input_grads, input_weight_grads, output_weight_grads,
input_bias_grads, output_bias_grads};
}
} // end namespace rocblas_gemmex
@@ -5,87 +5,66 @@ namespace multihead_attn {
namespace self {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
const uint8_t* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs,
torch::Tensor const &input_weights,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
);
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs, torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs, torch::Tensor const &input_weights,
torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(
use_time_mask,
is_training,
heads,
inputs,
input_weights,
output_weights,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
use_time_mask, is_training, heads, inputs, input_weights, output_weights,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr()) : nullptr,
dropout_prob);
}
std::vector<torch::Tensor> bwd(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
@@ -96,29 +75,28 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob
);
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_results, inputs, input_weights,
output_weights, dropout_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
@@ -129,4 +107,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self::rocblas_gemmex::fwd, "Self Multihead Attention Forward.");
m.def("backward", &multihead_attn::self::rocblas_gemmex::bwd, "Self Multihead Attention Backward.");
}
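For orientation, here is a minimal, hypothetical sketch of driving the `multihead_attn::self::rocblas_gemmex` binding above directly from C++. The tensor shapes, in particular the fused QKV weight orientation, are assumptions inferred only from the 2D/3D and HALF checks in `fwd`; this diff does not specify them.

#include <torch/extension.h>

// Hypothetical smoke test; assumes the extension above has been built and the
// fwd declaration is visible to the caller.
void self_attn_fwd_smoke_test() {
  const int heads = 16, embed_dim = 1024, seq_len = 64, batch = 8;
  auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
  torch::Tensor inputs = torch::randn({seq_len, batch, embed_dim}, opts);
  // Assumed fused QKV projection layout; the real shape comes from apex's
  // Python layer, not from this diff.
  torch::Tensor input_weights = torch::randn({3 * embed_dim, embed_dim}, opts);
  torch::Tensor output_weights = torch::randn({embed_dim, embed_dim}, opts);
  // Ignored when use_mask == false.
  torch::Tensor pad_mask = torch::empty({0, 0}, opts.dtype(torch::kUInt8));
  auto results = multihead_attn::self::rocblas_gemmex::fwd(
      /*use_mask=*/false, /*use_time_mask=*/false, /*is_training=*/true, heads,
      inputs, input_weights, output_weights, pad_mask, /*dropout_prob=*/0.1f);
  // Return order: input_lin_results, softmax_results, dropout_results,
  // dropout_mask, matmul2_results, outputs.
  torch::Tensor outputs = results.back();
  (void)outputs;
}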
#include <vector>
#include <math.h>
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h"
#include "layer_norm.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
#include "softmax.h"
#include "strided_batched_gemm.h"
namespace multihead_attn {
namespace self {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
const uint8_t* pad_mask,
float dropout_prob
)
{
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs,
torch::Tensor const &input_weights,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
@@ -55,24 +47,32 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = inputs.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor input_lin_results =
torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options);
// Input Linear Results Pointers to Q, K, and V of interleaved activations
void* q_lin_results_ptr = static_cast<void*>(input_lin_results.data_ptr());
void* k_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);
void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
void *k_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
@@ -132,40 +132,30 @@ std::vector<torch::Tensor> fwd_cuda(
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
} else {
if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
q_seq_len);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
}
assert(softmax_success);
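  // Dropout is applied only while training: the fused kernel reads the softmax
  // output, writes the dropped activations to dropout_results, and records the
  // keep decisions in the uint8 dropout_mask using the keep probability
  // (1.0f - dropout_prob).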
if (is_training) {
apex_fused_dropout_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(softmax_results.data_ptr()),
static_cast<at::Half*>(dropout_results.data_ptr()),
static_cast<uint8_t*>(dropout_mask.data_ptr()),
dropout_elems,
apex_fused_dropout_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(softmax_results.data_ptr()),
static_cast<at::Half *>(dropout_results.data_ptr()),
static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
(1.0f - dropout_prob));
}
@@ -219,30 +209,17 @@ std::vector<torch::Tensor> fwd_cuda(
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_lin_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs
};
return {input_lin_results, softmax_results, dropout_results,
dropout_mask, matmul2_results, outputs};
}
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
@@ -273,13 +250,17 @@ std::vector<torch::Tensor> bwd_cuda(
at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());
auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim;
auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
auto k_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;
auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim;
auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
@@ -397,12 +378,10 @@ std::vector<torch::Tensor> bwd_cuda(
// Softmax Grad
bool softmax_success = false;
softmax_success = dispatch_softmax_backward<half, half, float>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
assert(softmax_success);
// Matmul1 Dgrad1
@@ -5,111 +5,86 @@ namespace multihead_attn {
namespace self_norm_add {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
const uint8_t* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& lyr_nrm_results,
torch::Tensor const& lyr_nrm_mean,
torch::Tensor const& lyr_nrm_invvar,
torch::Tensor const& inputs,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
torch::Tensor const& dropout_add_mask,
float dropout_prob
);
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &pad_mask, float dropout_prob) {
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(
use_time_mask,
is_training,
heads,
inputs,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights,
output_weights,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights,
lyr_nrm_beta_weights, input_weights, output_weights,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr()) : nullptr,
dropout_prob);
}
std::vector<torch::Tensor> bwd(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& lyr_nrm_results,
torch::Tensor const& lyr_nrm_mean,
torch::Tensor const& lyr_nrm_invvar,
torch::Tensor const& inputs,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
torch::Tensor const& dropout_add_mask,
float dropout_prob
)
{
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
@@ -126,40 +101,42 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported");
AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return bwd_cuda(heads,
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
inputs,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights,
output_weights,
dropout_mask,
dropout_add_mask,
dropout_prob
);
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float,
"Only FLOAT is supported");
AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float,
"Only FLOAT is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_results, lyr_nrm_results,
lyr_nrm_mean, lyr_nrm_invvar, inputs, lyr_nrm_gamma_weights,
lyr_nrm_beta_weights, input_weights, output_weights,
dropout_mask, dropout_add_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
@@ -170,4 +147,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_norm_add::rocblas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def("backward", &multihead_attn::self_norm_add::rocblas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward.");
}
#include <vector>
#include <math.h>
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h"
#include "layer_norm.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
#include "softmax.h"
#include "strided_batched_gemm.h"
namespace multihead_attn {
namespace self_norm_add {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
const uint8_t* pad_mask,
float dropout_prob
)
{
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
@@ -58,7 +50,8 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = inputs.options().requires_grad(false);
auto lyr_nrm_options = act_options.dtype(torch::kFloat32);
auto mask_options = act_options.dtype(torch::kUInt8);
@@ -67,22 +60,29 @@ std::vector<torch::Tensor> fwd_cuda(
torch::Tensor lyr_nrm_invvar = torch::empty({batches}, lyr_nrm_options);
torch::Tensor lyr_nrm_results = torch::empty_like(inputs, act_options);
torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor output_lin_results= torch::empty_like(inputs, act_options);
torch::Tensor input_lin_results =
torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor output_lin_results = torch::empty_like(inputs, act_options);
torch::Tensor dropout_add_mask = torch::empty_like(inputs, mask_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options);
// Input Linear Results Pointers to Q, K, and V of interleaved activations
void* q_lin_results_ptr = static_cast<void*>(input_lin_results.data_ptr());
void* k_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);
void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
void *k_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
@@ -90,16 +90,15 @@ std::vector<torch::Tensor> fwd_cuda(
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm
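  // HostApplyLayerNorm normalizes each of the `batches` rows of length
  // `embed_dim` with eps = 1.0e-5 and stores the per-row mean and inverse
  // variance in fp32 so HostLayerNormGradient can reuse them in the backward.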
HostApplyLayerNorm<at::Half,float>(
static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
static_cast<float*>(lyr_nrm_mean.data_ptr()),
static_cast<float*>(lyr_nrm_invvar.data_ptr()),
static_cast<const at::Half*>(inputs.data_ptr()),
HostApplyLayerNorm<at::Half, float>(
static_cast<at::Half *>(lyr_nrm_results.data_ptr()),
static_cast<float *>(lyr_nrm_mean.data_ptr()),
static_cast<float *>(lyr_nrm_invvar.data_ptr()),
static_cast<const at::Half *>(inputs.data_ptr()),
static_cast<int>(batches), // n1
static_cast<int>(embed_dim), // n2
1.0e-5,
static_cast<const at::Half*>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -155,40 +154,30 @@ std::vector<torch::Tensor> fwd_cuda(
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
} else {
if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
q_seq_len);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
}
assert(softmax_success);
if (is_training) {
apex_fused_dropout_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(softmax_results.data_ptr()),
static_cast<at::Half*>(dropout_results.data_ptr()),
static_cast<uint8_t*>(dropout_mask.data_ptr()),
dropout_elems,
apex_fused_dropout_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(softmax_results.data_ptr()),
static_cast<at::Half *>(dropout_results.data_ptr()),
static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
(1.0f - dropout_prob));
}
@@ -245,57 +234,38 @@ std::vector<torch::Tensor> fwd_cuda(
// End-of-block Dropout-Add
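  // While training, the output projection is dropped out and added to the
  // residual `inputs`, with the keep decisions recorded in dropout_add_mask
  // for the backward pass; in eval mode apex_add_cuda performs the plain
  // residual add.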
if (is_training) {
apex_dropout_add_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_lin_results.data_ptr()),
static_cast<at::Half const*>(inputs.data_ptr()),
static_cast<at::Half*>(outputs.data_ptr()),
static_cast<uint8_t*>(dropout_add_mask.data_ptr()),
total_tokens,
apex_dropout_add_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(output_lin_results.data_ptr()),
static_cast<at::Half const *>(inputs.data_ptr()),
static_cast<at::Half *>(outputs.data_ptr()),
static_cast<uint8_t *>(dropout_add_mask.data_ptr()), total_tokens,
(1.0f - dropout_prob));
} else {
apex_add_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_lin_results.data_ptr()),
static_cast<at::Half const*>(inputs.data_ptr()),
static_cast<at::Half*>(outputs.data_ptr()),
total_tokens);
apex_add_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(output_lin_results.data_ptr()),
static_cast<at::Half const *>(inputs.data_ptr()),
static_cast<at::Half *>(outputs.data_ptr()), total_tokens);
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
input_lin_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
dropout_add_mask,
outputs
};
return {lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, input_lin_results,
softmax_results, dropout_results, dropout_mask, matmul2_results,
dropout_add_mask, outputs};
}
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& lyr_nrm_results,
torch::Tensor const& lyr_nrm_mean,
torch::Tensor const& lyr_nrm_invvar,
torch::Tensor const& inputs,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
torch::Tensor const& dropout_add_mask,
float dropout_prob
)
{
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
@@ -331,13 +301,17 @@ std::vector<torch::Tensor> bwd_cuda(
torch::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
torch::Tensor input_lin_grads = torch::empty_like(inputs);
auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());
auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim;
auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
auto k_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;
auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim;
auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
@@ -347,11 +321,10 @@ std::vector<torch::Tensor> bwd_cuda(
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Dropout Add Backward
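  // Reapply the saved dropout_add_mask to output_grads and rescale the kept
  // elements by 1.0 / (1.0 - dropout_prob) before running the output-linear
  // backward.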
apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_grads.data_ptr()),
static_cast<at::Half*>(dropout_add_grads.data_ptr()),
static_cast<uint8_t const*>(dropout_add_mask.data_ptr()),
total_tokens,
apex_masked_scale_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(output_grads.data_ptr()),
static_cast<at::Half *>(dropout_add_grads.data_ptr()),
static_cast<uint8_t const *>(dropout_add_mask.data_ptr()), total_tokens,
(1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad
@@ -463,12 +436,10 @@ std::vector<torch::Tensor> bwd_cuda(
// Softmax Grad
bool softmax_success = false;
softmax_success = dispatch_softmax_backward<half, half, float>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
assert(softmax_success);
// Matmul1 Dgrad1
@@ -572,31 +543,23 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// Fused Layer Norm Bwd with Residual Add
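  // HostLayerNormGradient reuses the fp32 mean/invvar saved in the forward and
  // folds the residual branch of output_grads into the layer-norm backward,
  // producing input_grads together with the gamma and beta gradients.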
HostLayerNormGradient<half,float>(
static_cast<const half*>(input_lin_grads.data_ptr()),
static_cast<half const*>(output_grads.data_ptr()),
static_cast<const float*>(lyr_nrm_mean.data_ptr()),
static_cast<const float*>(lyr_nrm_invvar.data_ptr()),
inputs,
HostLayerNormGradient<half, float>(
static_cast<const half *>(input_lin_grads.data_ptr()),
static_cast<half const *>(output_grads.data_ptr()),
static_cast<const float *>(lyr_nrm_mean.data_ptr()),
static_cast<const float *>(lyr_nrm_invvar.data_ptr()), inputs,
static_cast<int>(batches), // n1
static_cast<int>(embed_dim), // n2
static_cast<const half*>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const half*>(lyr_nrm_beta_weights.data_ptr()),
1.0e-5,
static_cast<half*>(input_grads.data_ptr()),
static_cast<half*>(lyr_nrm_gamma_grads.data_ptr()),
static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
);
static_cast<const half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const half *>(lyr_nrm_beta_weights.data_ptr()), 1.0e-5,
static_cast<half *>(input_grads.data_ptr()),
static_cast<half *>(lyr_nrm_gamma_grads.data_ptr()),
static_cast<half *>(lyr_nrm_beta_grads.data_ptr()));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_grads,
lyr_nrm_gamma_grads,
lyr_nrm_beta_grads,
input_weight_grads,
output_weight_grads
};
return {input_grads, lyr_nrm_gamma_grads, lyr_nrm_beta_grads,
input_weight_grads, output_weight_grads};
}
} // end namespace rocblas_gemmex