Commit 9615983e authored by Masaki Kozuki, committed by hubertlu-tw

Remove `THCState` from `apex/contrib/multihead_attn` (#1239)

* pass `self.mask_additive`

* clang-format

* removing THCState
parent d11ddccf
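Most of the C++/CUDA churn in this diff is clang-format reflow; the functional change visible in these files is that the `extern THCState *state;` declaration is dropped and the CUDA stream and cuBLAS handle are obtained from ATen instead. The following is a minimal sketch of that pattern (the helper name `bind_current_stream_to_cublas` is illustrative and not part of the commit):

#include <ATen/cuda/CUDAContext.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>

void bind_current_stream_to_cublas() {
  // These are the ATen getters used throughout the updated files below,
  // replacing any lookup that previously went through THCState.
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  // The kernels in these files are sequentially dependent, so a single
  // stream shared with cuBLAS is sufficient.
  cublasSetStream(handle, stream);
}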
#include <cuda_fp16.h>
#include <torch/extension.h>

#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {

std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const half *pad_mask, float dropout_prob);

torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x) \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
                               torch::Tensor const &input,
                               torch::Tensor const &pad_mask,
                               float dropout_prob) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
               "Only BYTE is supported");
  }

  return fwd_cuda(is_training, heads, input,
                  use_mask ? static_cast<const half *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
                  torch::Tensor const &softmax_results,
                  torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  // AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
  //            "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
                  dropout_prob);
}

} // namespace additive_mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd,
        "Self Multihead Attention masked softmax dropout -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd,
        "Self Multihead Attention masked softmax dropout -- Backward.");
}
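A hypothetical round trip through the two bindings above (the names `input`, `mask`, `grads`, and `roundtrip` are illustrative; the sketch assumes it is compiled in the same translation unit as the binding code), mainly to show which forward outputs feed the backward call:

void roundtrip(torch::Tensor input, torch::Tensor mask, torch::Tensor grads,
               int heads, float dropout_prob) {
  namespace amsd =
      multihead_attn::fused_softmax::additive_mask_softmax_dropout;
  auto outs = amsd::fwd(/*use_mask=*/true, /*is_training=*/true, heads, input,
                        mask, dropout_prob);
  // outs[0] = dropout_results, outs[1] = dropout_mask, outs[2] = softmax_results
  torch::Tensor input_grads =
      amsd::bwd(/*use_mask=*/true, heads, grads, /*softmax_results=*/outs[2],
                /*dropout_mask=*/outs[1], dropout_prob);
  (void)input_grads; // backward is in-place on `grads` and also returns it
}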
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "softmax.h"

// symbol to be automatically resolved by PyTorch libs
-extern THCState *state;

namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {

std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const half *pad_mask, float dropout_prob) {
  const int attn_batches = input.size(0);
  const int sequences = attn_batches / heads;
  const int q_seq_len = input.size(1);
  const int k_seq_len = q_seq_len;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = input.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *input_ptr = static_cast<void *>(input.data_ptr());
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
        attn_batches * q_seq_len);
  } else {
    softmax_success = dispatch_additive_masked_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
        k_seq_len, attn_batches * q_seq_len,
        attn_batches * q_seq_len / sequences);
  }

  if (is_training) {
    // use at:: function so that C++ version generates the same random mask as
    // python version
    auto dropout_tuple =
        at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
    dropout_results = std::get<0>(dropout_tuple);
    dropout_mask = std::get<1>(dropout_tuple);
  }

  // Matmul2

  return {dropout_results, dropout_mask, softmax_results};
}

torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask, float dropout_prob) {
  const int attn_batches = output_grads.size(0);
  const int q_seq_len = output_grads.size(1);
  const int k_seq_len = q_seq_len;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  // torch::Tensor input_grads = torch::empty_like(output_grads);

  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
  dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
      static_cast<half *>(output_grads.data_ptr()),
      static_cast<half *>(output_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()),
      static_cast<uint8_t const *>(dropout_mask.data_ptr()),
      1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
      attn_batches * q_seq_len, stream);
  // backward pass is completely in-place
  return output_grads;
}

} // namespace additive_mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
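The backward path above applies the saved dropout mask and the 1/(1 - p) scale while computing the softmax gradient in place. A CPU-side sketch of just the element-wise mask-and-scale rule (illustrative only, not code from the commit; the fused kernel combines this step with the softmax gradient):

#include <cstdint>
#include <vector>

std::vector<float> masked_scale(const std::vector<float> &grad,
                                const std::vector<uint8_t> &mask,
                                float dropout_prob) {
  const float scale = 1.0f / (1.0f - dropout_prob); // same factor as bwd_cuda
  std::vector<float> out(grad.size());
  for (size_t i = 0; i < grad.size(); ++i) {
    // mask[i] is 1 where the element survived dropout in the forward pass
    out[i] = grad[i] * static_cast<float>(mask[i]) * scale;
  }
  return out;
}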
@@ -11,202 +11,170 @@

const int UNROLL = 4;

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void
apex_fused_dropout_kernel(scalar_t const *inputs, scalar_t *outputs,
                          uint8_t *mask, IndexType totalElements, accscalar_t p,
                          std::pair<uint64_t, uint64_t> seeds) {
  accscalar_t pinv = accscalar_t(1) / p;
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seeds.first, idx, seeds.second, &state);
  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
      blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    float4 rand = curand_uniform4(&state);
    scalar_t src[UNROLL];
    rand.x = rand.x <= p;
    rand.y = rand.y <= p;
    rand.z = rand.z <= p;
    rand.w = rand.w <= p;
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        src[ii] = inputs[li];
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        outputs[li] = src[ii] * (&rand.x)[ii] * pinv;
        mask[li] = (uint8_t)(&rand.x)[ii];
      }
    }
    __syncthreads();
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_dropout_add_kernel(scalar_t const *inputs,
                                        scalar_t const *add_inputs,
                                        scalar_t *outputs, uint8_t *mask,
                                        IndexType totalElements, accscalar_t p,
                                        std::pair<uint64_t, uint64_t> seeds) {
  accscalar_t pinv = accscalar_t(1) / p;
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seeds.first, idx, seeds.second, &state);
  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
      blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    float4 rand = curand_uniform4(&state);
    scalar_t src[UNROLL];
    scalar_t add_src[UNROLL];
    rand.x = rand.x <= p;
    rand.y = rand.y <= p;
    rand.z = rand.z <= p;
    rand.w = rand.w <= p;
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        src[ii] = inputs[li];
        add_src[ii] = add_inputs[li];
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        accscalar_t int1 = src[ii] * (&rand.x)[ii] * pinv;
        outputs[li] =
            static_cast<scalar_t>(static_cast<accscalar_t>(add_src[ii]) + int1);
        mask[li] = (uint8_t)(&rand.x)[ii];
      }
    }
    __syncthreads();
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_add_kernel(scalar_t const *inputs,
                                scalar_t const *add_inputs, scalar_t *outputs,
                                IndexType totalElements) {
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
      blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    scalar_t src[UNROLL];
    scalar_t add_src[UNROLL];
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        src[ii] = inputs[li];
        add_src[ii] = add_inputs[li];
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        outputs[li] = src[ii] + add_src[ii];
      }
    }
    __syncthreads();
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_masked_scale_kernel(scalar_t const *inputs,
                                         scalar_t *outputs, uint8_t const *mask,
                                         IndexType totalElements,
                                         accscalar_t scale) {
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
      blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    scalar_t src[UNROLL];
    scalar_t msk[UNROLL];
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        src[ii] = static_cast<scalar_t>(inputs[li]);
        msk[ii] = static_cast<scalar_t>(mask[li]);
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        outputs[li] = static_cast<accscalar_t>(src[ii]) * scale *
                      static_cast<accscalar_t>(msk[ii]);
      }
    }
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_fused_dropout_cuda(scalar_t const *inputs, scalar_t *outputs,
                             uint8_t *mask, IndexType totalElements,
                             accscalar_t p) {
  auto gen = at::cuda::detail::getDefaultCUDAGenerator();
  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
      block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
                            ->multiProcessorCount *
                        blocks_per_sm,
                    grid.x);

  // number of times random will be generated per thread, to offset philox
  // counter in the random state
  int64_t counter_offset =
      ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
  std::pair<uint64_t, uint64_t> rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
@@ -215,36 +183,39 @@ void apex_fused_dropout_cuda(scalar_t const *inputs,
    rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
#else
    std::lock_guard<std::mutex> lock(gen.mutex());
    rng_engine_inputs =
        at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
            counter_offset);
#endif
  }
  apex_fused_dropout_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, outputs, mask, totalElements, p, rng_engine_inputs);
  C10_CUDA_CHECK(cudaGetLastError());
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_dropout_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs,
                           scalar_t *outputs, uint8_t *mask,
                           IndexType totalElements, accscalar_t p) {
  auto gen = at::cuda::detail::getDefaultCUDAGenerator();
  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
      block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
                            ->multiProcessorCount *
                        blocks_per_sm,
                    grid.x);

  // number of times random will be generated per thread, to offset philox
  // counter in the random state
  int64_t counter_offset =
      ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
  std::pair<uint64_t, uint64_t> rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
@@ -253,54 +224,56 @@ void apex_dropout_add_cuda(scalar_t const *inputs,
    rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
#else
    std::lock_guard<std::mutex> lock(gen.mutex());
    rng_engine_inputs =
        at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
            counter_offset);
#endif
  }
  apex_dropout_add_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, add_inputs, outputs, mask, totalElements, p,
          rng_engine_inputs);
  C10_CUDA_CHECK(cudaGetLastError());
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs,
                   scalar_t *outputs, IndexType totalElements) {
  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
      block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
                            ->multiProcessorCount *
                        blocks_per_sm,
                    grid.x);

  apex_add_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, add_inputs, outputs, totalElements);
  C10_CUDA_CHECK(cudaGetLastError());
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_masked_scale_cuda(scalar_t const *inputs, scalar_t *outputs,
                            uint8_t const *mask, IndexType totalElements,
                            accscalar_t scale) {
  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
      block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
                            ->multiProcessorCount *
                        blocks_per_sm,
                    grid.x);

  apex_masked_scale_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, outputs, mask, totalElements, scale);
  C10_CUDA_CHECK(cudaGetLastError());
}
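These launchers are host-side templates; the encdec forward pass further down instantiates apex_fused_dropout_cuda with <at::Half, float, uint32_t>. A small sketch of such a call (the wrapper name `run_fused_dropout` is illustrative; it assumes `softmax_results`, `dropout_results` are contiguous half CUDA tensors and `dropout_mask` a contiguous uint8 CUDA tensor with `dropout_elems` elements each, and that "dropout.h" and <torch/extension.h> are included):

void run_fused_dropout(torch::Tensor const &softmax_results,
                       torch::Tensor &dropout_results,
                       torch::Tensor &dropout_mask, int dropout_elems,
                       float dropout_prob) {
  apex_fused_dropout_cuda<at::Half, float, uint32_t>(
      static_cast<at::Half const *>(softmax_results.data_ptr()),
      static_cast<at::Half *>(dropout_results.data_ptr()),
      static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
      // note: the launcher's `p` is the keep-probability
      1.0f - dropout_prob);
}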
@@ -5,145 +5,121 @@ namespace multihead_attn {
namespace encdec {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs_q,
                                    torch::Tensor const &inputs_kv,
                                    torch::Tensor const &input_weights_q,
                                    torch::Tensor const &input_weights_kv,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_q_results,
    torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
    torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
    torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x) \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
    torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
    torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
                  input_weights_q, input_weights_kv, output_weights,
                  use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_q_results,
    torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
    torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
    torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_q_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_kv_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_q_results, input_lin_kv_results,
                  inputs_q, inputs_kv, input_weights_q, input_weights_kv,
                  output_weights, dropout_mask, dropout_prob);
}

} // end namespace rocblas_gemm_ex
......
#include <vector>
#include <math.h>
#include <iostream> #include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
//#include <cuda_profiler_api.h> //#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h" #include "dropout.h"
#include "layer_norm.h" #include "layer_norm.h"
#include "softmax.h"
// symbol to be automatically resolved by PyTorch libs #include "strided_batched_gemm.h"
extern THCState *state;
namespace multihead_attn { namespace multihead_attn {
namespace encdec { namespace encdec {
namespace rocblas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
bool use_time_mask, int heads, torch::Tensor const &inputs_q,
bool is_training, torch::Tensor const &inputs_kv,
int heads, torch::Tensor const &input_weights_q,
torch::Tensor const& inputs_q, torch::Tensor const &input_weights_kv,
torch::Tensor const& inputs_kv, torch::Tensor const &output_weights,
torch::Tensor const& input_weights_q, const uint8_t *pad_mask,
torch::Tensor const& input_weights_kv, float dropout_prob) {
torch::Tensor const& output_weights, const int embed_dim = inputs_q.size(2);
const uint8_t* pad_mask, const int sequences = inputs_q.size(1);
float dropout_prob const int q_seq_len = inputs_q.size(0);
) const int k_seq_len = inputs_kv.size(0);
{ const int batches_q = sequences * q_seq_len;
const int embed_dim = inputs_q.size(2); const int batches_kv = sequences * k_seq_len;
const int sequences = inputs_q.size(1); const int head_dim = embed_dim / heads;
const int q_seq_len = inputs_q.size(0); const int output_lin_q_dim = embed_dim;
const int k_seq_len = inputs_kv.size(0); const int output_lin_kv_dim = 2 * embed_dim;
const int batches_q = sequences * q_seq_len; const int attn_batches = heads * sequences;
const int batches_kv = sequences * k_seq_len; const int lead_dim_q = attn_batches * head_dim;
const int head_dim = embed_dim / heads; const int lead_dim_kv = attn_batches * 2 * head_dim;
const int output_lin_q_dim = embed_dim; const int batch_stride_q = head_dim;
const int output_lin_kv_dim = 2 * embed_dim; const int batch_stride_kv = 2 * head_dim;
const int attn_batches = heads * sequences; const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const int lead_dim_q = attn_batches * head_dim; const float alpha = 1.0;
const int lead_dim_kv = attn_batches * 2 *head_dim; const float beta = 0.0;
const int batch_stride_q = head_dim; const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len; // There is no reason to use more than one stream as every kernel is
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent // sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream); cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code) // 3 Intermediate Results + Output (Note: dropout intermediates are generated
auto act_options = inputs_q.options().requires_grad(false); // by ATen library code)
auto act_options = inputs_q.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8); auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor input_lin_q_results = torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options); torch::Tensor input_lin_q_results =
torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options); torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); torch::Tensor input_lin_kv_results =
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); torch::Tensor softmax_results =
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor outputs = torch::empty_like(inputs_q, act_options); torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor outputs = torch::empty_like(inputs_q, act_options);
// Input Linear Results Pointers to Q, K, and V of interviewed activations // Input Linear Results Pointers to Q, K, and V of interviewed activations
void* q_lin_results_ptr = static_cast<void*>(input_lin_q_results.data_ptr()); void *q_lin_results_ptr = static_cast<void *>(input_lin_q_results.data_ptr());
void* k_lin_results_ptr = static_cast<void*>(input_lin_kv_results.data_ptr()); void *k_lin_results_ptr =
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim); static_cast<void *>(input_lin_kv_results.data_ptr());
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr()); void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'}; char a_layout_t{'t'};
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
...@@ -166,43 +166,33 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -166,43 +166,33 @@ std::vector<torch::Tensor> fwd_cuda(
bool softmax_success = false; bool softmax_success = false;
if (pad_mask == nullptr) { if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>( softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr), reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
k_seq_len, k_seq_len, attn_batches * q_seq_len);
k_seq_len,
attn_batches*q_seq_len);
} else { } else {
if (use_time_mask) { if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>( softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr), reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
q_seq_len);
} else { } else {
softmax_success = dispatch_masked_softmax<half, half, float>( softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr), reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len,
k_seq_len, attn_batches * q_seq_len / sequences);
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
} }
} }
assert(softmax_success); assert(softmax_success);
if (is_training) { if (is_training) {
apex_fused_dropout_cuda<at::Half,float,uint32_t>( apex_fused_dropout_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const*>(softmax_results.data_ptr()), static_cast<at::Half const *>(softmax_results.data_ptr()),
static_cast<at::Half*>(dropout_results.data_ptr()), static_cast<at::Half *>(dropout_results.data_ptr()),
static_cast<uint8_t*>(dropout_mask.data_ptr()), static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
dropout_elems, (1.0f - dropout_prob));
(1.0f - dropout_prob));
} }
// Matmul2 // Matmul2
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
a_layout_n, a_layout_n,
...@@ -253,78 +243,73 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -253,78 +243,73 @@ std::vector<torch::Tensor> fwd_cuda(
flags)); flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {input_lin_q_results,
input_lin_q_results, input_lin_kv_results,
input_lin_kv_results, softmax_results,
softmax_results, dropout_results,
dropout_results, dropout_mask,
dropout_mask, matmul2_results,
matmul2_results, outputs};
outputs
};
} }
std::vector<torch::Tensor> bwd_cuda( std::vector<torch::Tensor> bwd_cuda(
int heads, int heads, torch::Tensor const &output_grads,
torch::Tensor const& output_grads, torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const& matmul2_results, torch::Tensor const &softmax_results,
torch::Tensor const& dropout_results, torch::Tensor const &input_lin_q_results,
torch::Tensor const& softmax_results, torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
torch::Tensor const& input_lin_q_results, torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
torch::Tensor const& input_lin_kv_results, torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
torch::Tensor const& inputs_q, torch::Tensor const &dropout_mask, float dropout_prob) {
torch::Tensor const& inputs_kv, const int embed_dim = inputs_q.size(2);
torch::Tensor const& input_weights_q, const int sequences = inputs_q.size(1);
torch::Tensor const& input_weights_kv, const int q_seq_len = inputs_q.size(0);
torch::Tensor const& output_weights, const int k_seq_len = inputs_kv.size(0);
torch::Tensor const& dropout_mask, const int batches_q = sequences * q_seq_len;
float dropout_prob const int batches_kv = sequences * k_seq_len;
) const int head_dim = embed_dim / heads;
{ const int output_lin_q_dim = embed_dim;
const int embed_dim = inputs_q.size(2); const int output_lin_kv_dim = 2 * embed_dim;
const int sequences = inputs_q.size(1); const int attn_batches = heads * sequences;
const int q_seq_len = inputs_q.size(0); const int lead_dim_q = attn_batches * head_dim;
const int k_seq_len = inputs_kv.size(0); const int lead_dim_kv = attn_batches * 2 * head_dim;
const int batches_q = sequences * q_seq_len; const int batch_stride_q = head_dim;
const int batches_kv = sequences * k_seq_len; const int batch_stride_kv = 2 * head_dim;
const int head_dim = embed_dim / heads; const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const int output_lin_q_dim = embed_dim; const float alpha = 1.0;
const int output_lin_kv_dim = 2 * embed_dim; const float beta = 0.0;
const int attn_batches = heads * sequences; const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const int lead_dim_q = attn_batches * head_dim;
const int lead_dim_kv = attn_batches * 2 *head_dim;
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// TODO: Streams can be used in Backprop but I haven't added more than one // TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code // in my first attempt to create the code
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream); cublasSetStream(handle, stream);
// Output Tensor Allocations // Output Tensor Allocations
torch::Tensor input_q_grads = torch::empty_like(inputs_q); torch::Tensor input_q_grads = torch::empty_like(inputs_q);
torch::Tensor input_kv_grads = torch::empty_like(inputs_kv); torch::Tensor input_kv_grads = torch::empty_like(inputs_kv);
torch::Tensor input_weight_q_grads = torch::empty_like(input_weights_q); torch::Tensor input_weight_q_grads = torch::empty_like(input_weights_q);
torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv); torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv);
torch::Tensor output_weight_grads = torch::empty_like(output_weights); torch::Tensor output_weight_grads = torch::empty_like(output_weights);
// Intermediate Tensor Allocations // Intermediate Tensor Allocations
at::Tensor output_lin_grads = torch::empty_like(matmul2_results); at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
at::Tensor matmul2_grads = torch::empty_like(dropout_results); at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results); at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results);
at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results); at::Tensor input_lin_kv_output_grads =
torch::empty_like(input_lin_kv_results);
auto q_lin_results_ptr = static_cast<half*>(input_lin_q_results.data_ptr());
auto k_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr()); auto q_lin_results_ptr = static_cast<half *>(input_lin_q_results.data_ptr());
auto v_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim; auto k_lin_results_ptr = static_cast<half *>(input_lin_kv_results.data_ptr());
auto v_lin_results_ptr =
auto q_lin_grads_ptr = static_cast<half*>(input_lin_q_output_grads.data_ptr()); static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim;
auto k_lin_grads_ptr = static_cast<half*>(input_lin_kv_output_grads.data_ptr());
auto v_lin_grads_ptr = static_cast<half*>(input_lin_kv_output_grads.data_ptr()) + head_dim; auto q_lin_grads_ptr =
static_cast<half *>(input_lin_q_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_kv_output_grads.data_ptr());
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_kv_output_grads.data_ptr()) + head_dim;
char a_layout_n{'n'}; char a_layout_n{'n'};
char a_layout_t{'t'}; char a_layout_t{'t'};
...@@ -442,12 +427,10 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -442,12 +427,10 @@ std::vector<torch::Tensor> bwd_cuda(
// Softmax Grad // Softmax Grad
bool softmax_success = false; bool softmax_success = false;
softmax_success = dispatch_softmax_backward<half, half, float>( softmax_success = dispatch_softmax_backward<half, half, float>(
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()), reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
k_seq_len, k_seq_len, attn_batches * q_seq_len);
k_seq_len,
attn_batches*q_seq_len);
assert(softmax_success); assert(softmax_success);
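For orientation, dispatch_softmax_backward applies the usual softmax Jacobian-vector product in place: it reads the saved softmax outputs (softmax_results) and the incoming gradients, and overwrites matmul2_grads with

\[
\frac{\partial L}{\partial x_i} \;=\; y_i\left(\frac{\partial L}{\partial y_i} - \sum_j y_j\,\frac{\partial L}{\partial y_j}\right),
\qquad y = \mathrm{softmax}(x),
\]

independently for each of the attn_batches * q_seq_len rows of length k_seq_len.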
// Matmul1 Dgrad1 // Matmul1 Dgrad1
...
@@ -5,194 +5,168 @@ namespace multihead_attn {
namespace encdec_norm_add { namespace encdec_norm_add {
namespace rocblas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
bool use_time_mask, int heads, torch::Tensor const &inputs_q,
bool is_training, torch::Tensor const &inputs_kv,
int heads, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const& inputs_q, torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const& inputs_kv, torch::Tensor const &input_weights_q,
torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const &input_weights_kv,
torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const &output_weights,
torch::Tensor const& input_weights_q, const uint8_t *pad_mask,
torch::Tensor const& input_weights_kv, float dropout_prob);
torch::Tensor const& output_weights,
const uint8_t* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> bwd_cuda( std::vector<torch::Tensor> bwd_cuda(
int heads, int heads, torch::Tensor const &output_grads,
torch::Tensor const& output_grads, torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const& matmul2_results, torch::Tensor const &softmax_results,
torch::Tensor const& dropout_results, torch::Tensor const &input_lin_q_results,
torch::Tensor const& softmax_results, torch::Tensor const &input_lin_kv_results,
torch::Tensor const& input_lin_q_results, torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const& input_lin_kv_results, torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q,
torch::Tensor const& lyr_nrm_results, torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_mean, torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const& lyr_nrm_invvar, torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const& inputs_q, torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
torch::Tensor const& inputs_kv, torch::Tensor const &dropout_add_mask, float dropout_prob);
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
torch::Tensor const& dropout_add_mask,
float dropout_prob
);
// C++ interface // C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CUDA(x) \
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) #define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd( std::vector<torch::Tensor>
bool use_mask, fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
bool use_time_mask, torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
bool is_training, torch::Tensor const &lyr_nrm_gamma_weights,
int heads, torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const& inputs_q, torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const& inputs_kv, torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
torch::Tensor const& lyr_nrm_gamma_weights, float dropout_prob) {
torch::Tensor const& lyr_nrm_beta_weights, AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
torch::Tensor const& input_weights_q, AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
torch::Tensor const& input_weights_kv, AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
torch::Tensor const& output_weights, AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
torch::Tensor const& pad_mask, AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
float dropout_prob AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
) AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
{
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor"); AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor"); "Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor"); AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor"); "Only HALF is supported");
AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor"); AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor"); "Only HALF is supported");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor"); AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
if (use_mask) { if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor"); AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported"); AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
} }
return fwd_cuda( return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
use_time_mask, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q,
is_training, input_weights_kv, output_weights,
heads, use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
inputs_q, : nullptr,
inputs_kv, dropout_prob);
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights_q,
input_weights_kv,
output_weights,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
} }
std::vector<torch::Tensor> bwd( std::vector<torch::Tensor>
int heads, bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const& output_grads, torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const& matmul2_results, torch::Tensor const &softmax_results,
torch::Tensor const& dropout_results, torch::Tensor const &input_lin_q_results,
torch::Tensor const& softmax_results, torch::Tensor const &input_lin_kv_results,
torch::Tensor const& input_lin_q_results, torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const& input_lin_kv_results, torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q,
torch::Tensor const& lyr_nrm_results, torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_mean, torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const& lyr_nrm_invvar, torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const& inputs_q, torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
torch::Tensor const& inputs_kv, torch::Tensor const &dropout_add_mask, float dropout_prob) {
torch::Tensor const& lyr_nrm_gamma_weights, AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
torch::Tensor const& lyr_nrm_beta_weights, AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
torch::Tensor const& input_weights_q, AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
torch::Tensor const& input_weights_kv, AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
torch::Tensor const& output_weights, AT_ASSERTM(input_lin_q_results.dim() == 3, "expected 3D tensor");
torch::Tensor const& dropout_mask, AT_ASSERTM(input_lin_kv_results.dim() == 3, "expected 3D tensor");
torch::Tensor const& dropout_add_mask, AT_ASSERTM(lyr_nrm_results.dim() == 3, "expected 3D tensor");
float dropout_prob AT_ASSERTM(lyr_nrm_mean.dim() == 1, "expected 1D tensor");
) AT_ASSERTM(lyr_nrm_invvar.dim() == 1, "expected 1D tensor");
{ AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_q_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_kv_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_mean.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_invvar.dim() == 1, "expected 1D tensor");
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor"); AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor"); AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor"); AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor"); AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor"); AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor"); AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor"); AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported"); "Only HALF is supported");
AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported"); AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float,
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only FLOAT is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported"); AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float,
AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported"); "Only FLOAT is supported");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
return bwd_cuda( "Only HALF is supported");
heads, AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
output_grads, "Only HALF is supported");
matmul2_results, AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
dropout_results, "Only HALF is supported");
softmax_results, AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
input_lin_q_results, "Only HALF is supported");
input_lin_kv_results, AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
lyr_nrm_results, "Only HALF is supported");
lyr_nrm_mean, AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
lyr_nrm_invvar, "Only HALF is supported");
inputs_q, AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
inputs_kv, "Only HALF is supported");
lyr_nrm_gamma_weights, AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
lyr_nrm_beta_weights, "Only BYTE is supported");
input_weights_q, AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte,
input_weights_kv, "Only BYTE is supported");
output_weights,
dropout_mask, return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
dropout_add_mask, softmax_results, input_lin_q_results, input_lin_kv_results,
dropout_prob lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs_q,
); inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights,
input_weights_q, input_weights_kv, output_weights,
dropout_mask, dropout_add_mask, dropout_prob);
} }
} // end namespace rocblas_gemmex } // end namespace rocblas_gemmex
} // end namespace encdec_norm_add } // end namespace encdec_norm_add
} // end namespace multihead_attn } // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::encdec_norm_add::rocblas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward."); m.def("forward", &multihead_attn::encdec_norm_add::rocblas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def("backward", &multihead_attn::encdec_norm_add::rocblas_gemmex::bwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward."); m.def("backward", &multihead_attn::encdec_norm_add::rocblas_gemmex::bwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.");
} }
#include <vector>
#include <math.h>
#include <iostream> #include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
//#include <cuda_profiler_api.h> //#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h" #include "dropout.h"
#include "layer_norm.h" #include "layer_norm.h"
#include "softmax.h"
// symbol to be automatically resolved by PyTorch libs #include "strided_batched_gemm.h"
extern THCState *state;
namespace multihead_attn { namespace multihead_attn {
namespace encdec_norm_add { namespace encdec_norm_add {
@@ -61,52 +58,60 @@ std::vector<torch::Tensor> fwd_cuda(
// There is no reason to use more than one stream as every kernel is // There is no reason to use more than one stream as every kernel is
// sequentially dependent // sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream); cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code) // 3 Intermediate Results + Output (Note: dropout intermediates are generated
auto act_options = inputs_q.options().requires_grad(false); // by ATen library code)
auto lyr_nrm_options = act_options.dtype(torch::kFloat32); auto act_options = inputs_q.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8); auto lyr_nrm_options = act_options.dtype(torch::kFloat32);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor lyr_nrm_mean = torch::empty({batches_q}, lyr_nrm_options);
torch::Tensor lyr_nrm_invvar = torch::empty({batches_q}, lyr_nrm_options); torch::Tensor lyr_nrm_mean = torch::empty({batches_q}, lyr_nrm_options);
torch::Tensor lyr_nrm_results = torch::empty_like(inputs_q, act_options); torch::Tensor lyr_nrm_invvar = torch::empty({batches_q}, lyr_nrm_options);
torch::Tensor lyr_nrm_results = torch::empty_like(inputs_q, act_options);
torch::Tensor input_lin_q_results = torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options);
torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options); torch::Tensor input_lin_q_results =
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); torch::Tensor input_lin_kv_results =
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); torch::Tensor softmax_results =
torch::Tensor output_lin_results = torch::empty_like(inputs_q, act_options); torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_add_mask = torch::empty_like(inputs_q, mask_options); torch::Tensor dropout_results =
torch::Tensor outputs = torch::empty_like(inputs_q, act_options); torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor output_lin_results = torch::empty_like(inputs_q, act_options);
torch::Tensor dropout_add_mask = torch::empty_like(inputs_q, mask_options);
torch::Tensor outputs = torch::empty_like(inputs_q, act_options);
// Input Linear Results Pointers to Q, K, and V of interleaved activations // Input Linear Results Pointers to Q, K, and V of interleaved activations
void* q_lin_results_ptr = static_cast<void*>(input_lin_q_results.data_ptr()); void *q_lin_results_ptr = static_cast<void *>(input_lin_q_results.data_ptr());
void* k_lin_results_ptr = static_cast<void*>(input_lin_kv_results.data_ptr()); void *k_lin_results_ptr =
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim); static_cast<void *>(input_lin_kv_results.data_ptr());
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr()); void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'}; char a_layout_t{'t'};
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm // Layer Norm
HostApplyLayerNorm<at::Half,float>( HostApplyLayerNorm<at::Half, float>(
static_cast<at::Half*>(lyr_nrm_results.data_ptr()), static_cast<at::Half *>(lyr_nrm_results.data_ptr()),
static_cast<float*>(lyr_nrm_mean.data_ptr()), static_cast<float *>(lyr_nrm_mean.data_ptr()),
static_cast<float*>(lyr_nrm_invvar.data_ptr()), static_cast<float *>(lyr_nrm_invvar.data_ptr()),
static_cast<const at::Half*>(inputs_q.data_ptr()), static_cast<const at::Half *>(inputs_q.data_ptr()),
static_cast<int>(batches_q), // n1 static_cast<int>(batches_q), // n1
static_cast<int>(embed_dim), // n2 static_cast<int>(embed_dim), // n2
1.0e-5, 1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const at::Half*>(lyr_nrm_gamma_weights.data_ptr()), static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Q Fwd // Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -187,41 +192,31 @@ std::vector<torch::Tensor> fwd_cuda(
bool softmax_success = false; bool softmax_success = false;
if (pad_mask == nullptr) { if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>( softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr), reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
k_seq_len, k_seq_len, attn_batches * q_seq_len);
k_seq_len,
attn_batches*q_seq_len);
} else { } else {
if (use_time_mask) { if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>( softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr), reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
q_seq_len);
} else { } else {
softmax_success = dispatch_masked_softmax<half, half, float>( softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr), reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len,
k_seq_len, attn_batches * q_seq_len / sequences);
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
} }
} }
assert(softmax_success); assert(softmax_success);
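The three dispatch variants above differ only in how masking enters the softmax. Assuming the conventional padding-mask semantics, a masked softmax over one row of scores \(s\) with a binary mask \(m\) (1 = keep, 0 = masked) is

\[
\mathrm{softmax}_{\text{masked}}(s)_j \;=\; \frac{m_j\, e^{s_j}}{\sum_k m_k\, e^{s_k}},
\]

which is equivalent to adding \(-\infty\) to the masked scores before a plain softmax; the time-masked variant presumably applies a per-query-position (causal-style) mask rather than a per-sequence pad mask.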
if (is_training) { if (is_training) {
apex_fused_dropout_cuda<at::Half,float,uint32_t>( apex_fused_dropout_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const*>(softmax_results.data_ptr()), static_cast<at::Half const *>(softmax_results.data_ptr()),
static_cast<at::Half*>(dropout_results.data_ptr()), static_cast<at::Half *>(dropout_results.data_ptr()),
static_cast<uint8_t*>(dropout_mask.data_ptr()), static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
dropout_elems, (1.0f - dropout_prob));
(1.0f - dropout_prob));
} }
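The fused dropout call above uses the inverted-dropout convention, assuming apex_fused_dropout_cuda scales by the reciprocal of the keep probability it is handed (note that (1.0f - dropout_prob) is the keep probability):

\[
y_i = \frac{m_i}{1-p}\, x_i, \qquad m_i \sim \mathrm{Bernoulli}(1-p),
\]

with the sampled mask bytes stored in dropout_mask for reuse in the backward pass; when is_training is false this step is skipped entirely.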
// Matmul2 // Matmul2
@@ -276,110 +271,101 @@ std::vector<torch::Tensor> fwd_cuda(
// End-of-block Dropout-Add // End-of-block Dropout-Add
if (is_training) { if (is_training) {
apex_dropout_add_cuda<at::Half,float,uint32_t>( apex_dropout_add_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const*>(output_lin_results.data_ptr()), static_cast<at::Half const *>(output_lin_results.data_ptr()),
static_cast<at::Half const*>(inputs_q.data_ptr()), static_cast<at::Half const *>(inputs_q.data_ptr()),
static_cast<at::Half*>(outputs.data_ptr()), static_cast<at::Half *>(outputs.data_ptr()),
static_cast<uint8_t*>(dropout_add_mask.data_ptr()), static_cast<uint8_t *>(dropout_add_mask.data_ptr()), total_tokens_q,
total_tokens_q, (1.0f - dropout_prob));
(1.0f - dropout_prob));
} else { } else {
apex_add_cuda<at::Half,float,uint32_t>( apex_add_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const*>(output_lin_results.data_ptr()), static_cast<at::Half const *>(output_lin_results.data_ptr()),
static_cast<at::Half const*>(inputs_q.data_ptr()), static_cast<at::Half const *>(inputs_q.data_ptr()),
static_cast<at::Half*>(outputs.data_ptr()), static_cast<at::Half *>(outputs.data_ptr()), total_tokens_q);
total_tokens_q);
} }
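The end-of-block step fuses the residual connection with dropout on the output projection. Under the same inverted-dropout convention as above, training computes

\[
\text{outputs} \;=\; \text{inputs\_q} + \mathrm{Dropout}_p\big(\text{output\_lin\_results}\big),
\]

recording the mask in dropout_add_mask, while evaluation falls through to the plain apex_add_cuda branch, outputs = inputs_q + output_lin_results.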
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {lyr_nrm_results,
lyr_nrm_results, lyr_nrm_mean,
lyr_nrm_mean, lyr_nrm_invvar,
lyr_nrm_invvar, input_lin_q_results,
input_lin_q_results, input_lin_kv_results,
input_lin_kv_results, softmax_results,
softmax_results, dropout_results,
dropout_results, dropout_mask,
dropout_mask, matmul2_results,
matmul2_results, dropout_add_mask,
dropout_add_mask, outputs};
outputs
};
} }
std::vector<torch::Tensor> bwd_cuda( std::vector<torch::Tensor> bwd_cuda(
int heads, int heads, torch::Tensor const &output_grads,
torch::Tensor const& output_grads, torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const& matmul2_results, torch::Tensor const &softmax_results,
torch::Tensor const& dropout_results, torch::Tensor const &input_lin_q_results,
torch::Tensor const& softmax_results, torch::Tensor const &input_lin_kv_results,
torch::Tensor const& input_lin_q_results, torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const& input_lin_kv_results, torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q,
torch::Tensor const& lyr_nrm_results, torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_mean, torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const& lyr_nrm_invvar, torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const& inputs_q, torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
torch::Tensor const& inputs_kv, torch::Tensor const &dropout_add_mask, float dropout_prob) {
torch::Tensor const& lyr_nrm_gamma_weights, const int embed_dim = inputs_q.size(2);
torch::Tensor const& lyr_nrm_beta_weights, const int sequences = inputs_q.size(1);
torch::Tensor const& input_weights_q, const int q_seq_len = inputs_q.size(0);
torch::Tensor const& input_weights_kv, const int k_seq_len = inputs_kv.size(0);
torch::Tensor const& output_weights, const int batches_q = sequences * q_seq_len;
torch::Tensor const& dropout_mask, const int batches_kv = sequences * k_seq_len;
torch::Tensor const& dropout_add_mask, const int total_tokens_q = batches_q * embed_dim;
float dropout_prob const int head_dim = embed_dim / heads;
) const int output_lin_q_dim = embed_dim;
{ const int output_lin_kv_dim = 2 * embed_dim;
const int embed_dim = inputs_q.size(2); const int attn_batches = heads * sequences;
const int sequences = inputs_q.size(1); const int lead_dim_q = attn_batches * head_dim;
const int q_seq_len = inputs_q.size(0); const int lead_dim_kv = attn_batches * 2 * head_dim;
const int k_seq_len = inputs_kv.size(0); const int batch_stride_q = head_dim;
const int batches_q = sequences * q_seq_len; const int batch_stride_kv = 2 * head_dim;
const int batches_kv = sequences * k_seq_len; const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const int total_tokens_q = batches_q * embed_dim; const float alpha = 1.0;
const int head_dim = embed_dim / heads; const float beta = 0.0;
const int output_lin_q_dim = embed_dim; const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const int output_lin_kv_dim = 2 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim_q = attn_batches * head_dim;
const int lead_dim_kv = attn_batches * 2 *head_dim;
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// TODO: Streams can be used in Backprop but I haven't added more than one // TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code // in my first attempt to create the code
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream); cublasSetStream(handle, stream);
// Output Tensor Allocations // Output Tensor Allocations
torch::Tensor input_q_grads = torch::empty_like(inputs_q); torch::Tensor input_q_grads = torch::empty_like(inputs_q);
torch::Tensor input_kv_grads = torch::empty_like(inputs_kv); torch::Tensor input_kv_grads = torch::empty_like(inputs_kv);
torch::Tensor lyr_nrm_gamma_grads = torch::empty_like(lyr_nrm_gamma_weights); torch::Tensor lyr_nrm_gamma_grads = torch::empty_like(lyr_nrm_gamma_weights);
torch::Tensor lyr_nrm_beta_grads = torch::empty_like(lyr_nrm_beta_weights); torch::Tensor lyr_nrm_beta_grads = torch::empty_like(lyr_nrm_beta_weights);
torch::Tensor input_weight_q_grads = torch::empty_like(input_weights_q); torch::Tensor input_weight_q_grads = torch::empty_like(input_weights_q);
torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv); torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv);
torch::Tensor output_weight_grads = torch::empty_like(output_weights); torch::Tensor output_weight_grads = torch::empty_like(output_weights);
// Intermediate Tensor Allocations // Intermediate Tensor Allocations
at::Tensor dropout_add_grads = torch::empty_like(output_grads); at::Tensor dropout_add_grads = torch::empty_like(output_grads);
at::Tensor output_lin_grads = torch::empty_like(matmul2_results); at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
at::Tensor matmul2_grads = torch::empty_like(dropout_results); at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results); at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results);
at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results); at::Tensor input_lin_kv_output_grads =
at::Tensor input_lin_q_grads = torch::empty_like(inputs_q); torch::empty_like(input_lin_kv_results);
at::Tensor input_lin_q_grads = torch::empty_like(inputs_q);
auto q_lin_results_ptr = static_cast<half*>(input_lin_q_results.data_ptr());
auto k_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr()); auto q_lin_results_ptr = static_cast<half *>(input_lin_q_results.data_ptr());
auto v_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim; auto k_lin_results_ptr = static_cast<half *>(input_lin_kv_results.data_ptr());
auto v_lin_results_ptr =
auto q_lin_grads_ptr = static_cast<half*>(input_lin_q_output_grads.data_ptr()); static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim;
auto k_lin_grads_ptr = static_cast<half*>(input_lin_kv_output_grads.data_ptr());
auto v_lin_grads_ptr = static_cast<half*>(input_lin_kv_output_grads.data_ptr()) + head_dim; auto q_lin_grads_ptr =
static_cast<half *>(input_lin_q_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_kv_output_grads.data_ptr());
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_kv_output_grads.data_ptr()) + head_dim;
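  // Layout note (inferred from the pointer arithmetic above): the fused KV
  // projection stores K and V interleaved per head, so for every attention
  // batch the first head_dim values are K and the next head_dim values are V.
  // That is why the V pointers are just the K pointers offset by head_dim, and
  // why the KV tensors use lead_dim_kv = attn_batches * 2 * head_dim and
  // batch_stride_kv = 2 * head_dim in the strided-batched GEMMs.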
char a_layout_n{'n'}; char a_layout_n{'n'};
char a_layout_t{'t'}; char a_layout_t{'t'};
@@ -505,12 +491,10 @@ std::vector<torch::Tensor> bwd_cuda(
// Softmax Grad // Softmax Grad
bool softmax_success = false; bool softmax_success = false;
softmax_success = dispatch_softmax_backward<half, half, float>( softmax_success = dispatch_softmax_backward<half, half, float>(
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()), reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
k_seq_len, k_seq_len, attn_batches * q_seq_len);
k_seq_len,
attn_batches*q_seq_len);
assert(softmax_success); assert(softmax_success);
// Matmul1 Dgrad1 // Matmul1 Dgrad1
@@ -683,15 +667,9 @@ std::vector<torch::Tensor> bwd_cuda(
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {input_q_grads, input_kv_grads, lyr_nrm_gamma_grads,
input_q_grads, lyr_nrm_beta_grads, input_weight_q_grads, input_weight_kv_grads,
input_kv_grads, output_weight_grads};
lyr_nrm_gamma_grads,
lyr_nrm_beta_grads,
input_weight_q_grads,
input_weight_kv_grads,
output_weight_grads
};
} }
} // end namespace rocblas_gemmex } // end namespace rocblas_gemmex
...
@@ -4,14 +4,8 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
template <typename U>
template<typename U> __device__ __device__ void cuWelfordOnlineSum(const U curr, U &mu, U &sigma2, U &count) {
void cuWelfordOnlineSum(
const U curr,
U& mu,
U& sigma2,
U& count)
{
count = count + U(1); count = count + U(1);
U delta = curr - mu; U delta = curr - mu;
U lmean = mu + delta / count; U lmean = mu + delta / count;
@@ -20,15 +14,9 @@ void cuWelfordOnlineSum(
sigma2 = sigma2 + delta * delta2; sigma2 = sigma2 + delta * delta2;
} }
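The update above is Welford's online algorithm: for the k-th element \(x_k\) it maintains a running mean and the running sum of squared deviations (kept in sigma2 and divided by n2 only at the very end):

\[
\delta = x_k - \mu_{k-1}, \qquad \mu_k = \mu_{k-1} + \frac{\delta}{k}, \qquad M_{2,k} = M_{2,k-1} + \delta\,(x_k - \mu_k).
\]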
template<typename U> __device__ template <typename U>
void cuChanOnlineSum( __device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB,
const U muB, U &mu, U &sigma2, U &count) {
const U sigma2B,
const U countB,
U& mu,
U& sigma2,
U& count)
{
U delta = muB - mu; U delta = muB - mu;
U nA = count; U nA = count;
U nB = countB; U nB = countB;
@@ -37,7 +25,7 @@ void cuChanOnlineSum(
if (nX > U(0)) { if (nX > U(0)) {
nA = nA / nX; nA = nA / nX;
nB = nB / nX; nB = nB / nX;
mu = nA*mu + nB*muB; mu = nA * mu + nB * muB;
sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
} else { } else {
mu = U(0); mu = U(0);
@@ -45,16 +33,10 @@ void cuChanOnlineSum(
} }
} }
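cuChanOnlineSum merges two partial Welford states, the caller's (A) and an incoming one (B), using Chan et al.'s parallel combination; with \(n = n_A + n_B\) and \(\delta = \mu_B - \mu_A\),

\[
\mu = \frac{n_A\,\mu_A + n_B\,\mu_B}{n},
\qquad
M_2 = M_{2,A} + M_{2,B} + \delta^2\,\frac{n_A\,n_B}{n},
\]

and the n = 0 branch simply resets the state so that empty partitions merge harmlessly.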
template<typename T, typename U> __device__ template <typename T, typename U>
void cuWelfordMuSigma2( __device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1,
const T* __restrict__ vals, const int n2, const int i1, U &mu, U &sigma2,
const int n1, U *buf) {
const int n2,
const int i1,
U& mu,
U& sigma2,
U* buf)
{
// Assumptions: // Assumptions:
// 1) blockDim.x == warpSize // 1) blockDim.x == warpSize
// 2) Tensor is contiguous // 2) Tensor is contiguous
@@ -62,7 +44,7 @@ void cuWelfordMuSigma2(
// //
// compute variance and mean over n2 // compute variance and mean over n2
U count = U(0); U count = U(0);
mu= U(0); mu = U(0);
sigma2 = U(0); sigma2 = U(0);
if (i1 < n1) { if (i1 < n1) {
// one warp normalizes one n1 index, // one warp normalizes one n1 index,
@@ -70,17 +52,17 @@ void cuWelfordMuSigma2(
// initialize with standard Welford algorithm // initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y; const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x; const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const T* lvals = vals + i1*n2; const T *lvals = vals + i1 * n2;
int l = 4*thrx; int l = 4 * thrx;
for (; l+3 < n2; l+=4*numx) { for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) { for (int k = 0; k < 4; ++k) {
U curr = static_cast<U>(lvals[l+k]); U curr = static_cast<U>(lvals[l + k]);
cuWelfordOnlineSum<U>(curr,mu,sigma2,count); cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
} }
} }
for (; l < n2; ++l) { for (; l < n2; ++l) {
U curr = static_cast<U>(lvals[l]); U curr = static_cast<U>(lvals[l]);
cuWelfordOnlineSum<U>(curr,mu,sigma2,count); cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
} }
// intra-warp reductions // intra-warp reductions
for (int l = 0; l <= 4; ++l) { for (int l = 0; l <= 4; ++l) {
@@ -93,23 +75,24 @@ void cuWelfordMuSigma2(
// threadIdx.x == 0 has correct values for each warp // threadIdx.x == 0 has correct values for each warp
// inter-warp reductions // inter-warp reductions
if (blockDim.y > 1) { if (blockDim.y > 1) {
U* ubuf = (U*)buf; U *ubuf = (U *)buf;
U* ibuf = (U*)(ubuf + blockDim.y); U *ibuf = (U *)(ubuf + blockDim.y);
for (int offset = blockDim.y/2; offset > 0; offset /= 2) { for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared // upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { if (threadIdx.x == 0 && threadIdx.y >= offset &&
threadIdx.y < 2 * offset) {
const int wrt_y = threadIdx.y - offset; const int wrt_y = threadIdx.y - offset;
ubuf[2*wrt_y] = mu; ubuf[2 * wrt_y] = mu;
ubuf[2*wrt_y+1] = sigma2; ubuf[2 * wrt_y + 1] = sigma2;
ibuf[wrt_y] = count; ibuf[wrt_y] = count;
} }
__syncthreads(); __syncthreads();
// lower half merges // lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) { if (threadIdx.x == 0 && threadIdx.y < offset) {
U muB = ubuf[2*threadIdx.y]; U muB = ubuf[2 * threadIdx.y];
U sigma2B = ubuf[2*threadIdx.y+1]; U sigma2B = ubuf[2 * threadIdx.y + 1];
U countB = ibuf[threadIdx.y]; U countB = ibuf[threadIdx.y];
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count); cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
} }
__syncthreads(); __syncthreads();
} }
@@ -120,7 +103,7 @@ void cuWelfordMuSigma2(
} }
__syncthreads(); __syncthreads();
mu = ubuf[0]; mu = ubuf[0];
sigma2 = ubuf[1]/U(n2); sigma2 = ubuf[1] / U(n2);
// don't care about final value of count, we know count == n2 // don't care about final value of count, we know count == n2
} else { } else {
mu = WARP_SHFL(mu, 0, 32); mu = WARP_SHFL(mu, 0, 32);
@@ -129,16 +112,10 @@ void cuWelfordMuSigma2(
} }
} }
template<> __device__ template <>
void cuWelfordMuSigma2( __device__ void cuWelfordMuSigma2(const at::Half *__restrict__ vals,
const at::Half* __restrict__ vals, const int n1, const int n2, const int i1,
const int n1, float &mu, float &sigma2, float *buf) {
const int n2,
const int i1,
float& mu,
float& sigma2,
float* buf)
{
// Assumptions: // Assumptions:
// 1) blockDim.x == warpSize // 1) blockDim.x == warpSize
// 2) Tensor is contiguous // 2) Tensor is contiguous
@@ -146,7 +123,7 @@ void cuWelfordMuSigma2(
// //
// compute variance and mean over n2 // compute variance and mean over n2
float count = 0.0f; float count = 0.0f;
mu= float(0); mu = float(0);
sigma2 = float(0); sigma2 = float(0);
if (i1 < n1) { if (i1 < n1) {
@@ -155,28 +132,28 @@ void cuWelfordMuSigma2(
// initialize with standard Welford algorithm // initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y; const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x; const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const at::Half* lvals = vals + i1*n2; const at::Half *lvals = vals + i1 * n2;
int l = 8*thrx; int l = 8 * thrx;
if ((((size_t)lvals)&3) != 0) { if ((((size_t)lvals) & 3) != 0) {
// 16 bit alignment // 16 bit alignment
// first thread consumes first point // first thread consumes first point
if (thrx == 0) { if (thrx == 0) {
float curr = static_cast<float>(lvals[0]); float curr = static_cast<float>(lvals[0]);
cuWelfordOnlineSum(curr,mu,sigma2,count); cuWelfordOnlineSum(curr, mu, sigma2, count);
} }
++l; ++l;
} }
// at this point, lvals[l] are 32 bit aligned for all threads. // at this point, lvals[l] are 32 bit aligned for all threads.
for (; l+7 < n2; l+=8*numx) { for (; l + 7 < n2; l += 8 * numx) {
for (int k = 0; k < 8; k+=2) { for (int k = 0; k < 8; k += 2) {
float2 curr = __half22float2(*((__half2*)(lvals+l+k))); float2 curr = __half22float2(*((__half2 *)(lvals + l + k)));
cuWelfordOnlineSum(curr.x,mu,sigma2,count); cuWelfordOnlineSum(curr.x, mu, sigma2, count);
cuWelfordOnlineSum(curr.y,mu,sigma2,count); cuWelfordOnlineSum(curr.y, mu, sigma2, count);
} }
} }
for (; l < n2; ++l) { for (; l < n2; ++l) {
float curr = static_cast<float>(lvals[l]); float curr = static_cast<float>(lvals[l]);
cuWelfordOnlineSum(curr,mu,sigma2,count); cuWelfordOnlineSum(curr, mu, sigma2, count);
} }
// intra-warp reductions // intra-warp reductions
for (int l = 0; l <= 4; ++l) { for (int l = 0; l <= 4; ++l) {
...@@ -189,23 +166,24 @@ void cuWelfordMuSigma2( ...@@ -189,23 +166,24 @@ void cuWelfordMuSigma2(
// threadIdx.x == 0 has correct values for each warp // threadIdx.x == 0 has correct values for each warp
// inter-warp reductions // inter-warp reductions
if (blockDim.y > 1) { if (blockDim.y > 1) {
float* ubuf = (float*)buf; float *ubuf = (float *)buf;
float* ibuf = (float*)(ubuf + blockDim.y); float *ibuf = (float *)(ubuf + blockDim.y);
for (int offset = blockDim.y/2; offset > 0; offset /= 2) { for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared // upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { if (threadIdx.x == 0 && threadIdx.y >= offset &&
threadIdx.y < 2 * offset) {
const int wrt_y = threadIdx.y - offset; const int wrt_y = threadIdx.y - offset;
ubuf[2*wrt_y] = mu; ubuf[2 * wrt_y] = mu;
ubuf[2*wrt_y+1] = sigma2; ubuf[2 * wrt_y + 1] = sigma2;
ibuf[wrt_y] = count; ibuf[wrt_y] = count;
} }
__syncthreads(); __syncthreads();
// lower half merges // lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) { if (threadIdx.x == 0 && threadIdx.y < offset) {
float muB = ubuf[2*threadIdx.y]; float muB = ubuf[2 * threadIdx.y];
float sigma2B = ubuf[2*threadIdx.y+1]; float sigma2B = ubuf[2 * threadIdx.y + 1];
float countB = ibuf[threadIdx.y]; float countB = ibuf[threadIdx.y];
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
} }
__syncthreads(); __syncthreads();
} }
@@ -216,7 +194,7 @@ void cuWelfordMuSigma2(
} }
__syncthreads(); __syncthreads();
mu = ubuf[0]; mu = ubuf[0];
sigma2 = ubuf[1]/float(n2); sigma2 = ubuf[1] / float(n2);
// don't care about final value of count, we know count == n2 // don't care about final value of count, we know count == n2
} else { } else {
mu = WARP_SHFL(mu, 0, 32); mu = WARP_SHFL(mu, 0, 32);
@@ -246,8 +224,9 @@ template<> double rsqrt(double v) {
} }
namespace { namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this // This is the un-specialized struct. Note that we prevent instantiation of
// struct by putting an undefined symbol in the function body so it won't compile. // this struct by putting an undefined symbol in the function body so it won't
// compile.
// template <typename T> // template <typename T>
// struct SharedMemory // struct SharedMemory
// { // {
@@ -260,64 +239,50 @@ namespace {
// } // }
// }; // };
// https://github.com/NVIDIA/apex/issues/246 // https://github.com/NVIDIA/apex/issues/246
template <typename T> template <typename T> struct SharedMemory;
struct SharedMemory;
template <> template <> struct SharedMemory<float> {
struct SharedMemory <float> __device__ float *getPointer() {
{ extern __shared__ float s_float[];
__device__ float *getPointer() return s_float;
{ }
extern __shared__ float s_float[];
return s_float;
}
}; };
template <> template <> struct SharedMemory<double> {
struct SharedMemory <double> __device__ double *getPointer() {
{ extern __shared__ double s_double[];
__device__ double *getPointer() return s_double;
{ }
extern __shared__ double s_double[];
return s_double;
}
}; };
} } // namespace
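A minimal usage sketch of the SharedMemory workaround (a hypothetical kernel, not taken from this file); its only job is to give each instantiating type its own named extern __shared__ array so templated kernels can obtain a typed pointer to the dynamic shared-memory allocation:

template <typename U>
__global__ void block_sum_sketch(const U *in, U *out, int n) {
  // assumes blockDim.x is a power of two and n <= blockDim.x
  SharedMemory<U> shared;
  U *buf = shared.getPointer(); // aliases the dynamic shared-memory region
  const int tid = threadIdx.x;
  buf[tid] = (tid < n) ? in[tid] : U(0);
  __syncthreads();
  // standard in-block tree reduction
  for (int offset = blockDim.x / 2; offset > 0; offset /= 2) {
    if (tid < offset)
      buf[tid] += buf[tid + offset];
    __syncthreads();
  }
  if (tid == 0)
    *out = buf[0];
}
// e.g. block_sum_sketch<float><<<1, 256, 256 * sizeof(float)>>>(d_in, d_out, n);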
template<typename T, typename U> __global__ template <typename T, typename U>
void cuApplyLayerNorm( __global__ void
T* __restrict__ output_vals, cuApplyLayerNorm(T *__restrict__ output_vals, U *__restrict__ mean,
U* __restrict__ mean, U *__restrict__ invvar, const T *__restrict__ vals,
U* __restrict__ invvar, const int n1, const int n2, const U epsilon,
const T* __restrict__ vals, const T *__restrict__ gamma, const T *__restrict__ beta) {
const int n1,
const int n2,
const U epsilon,
const T* __restrict__ gamma,
const T* __restrict__ beta
)
{
// Assumptions: // Assumptions:
// 1) blockDim.x == warpSize // 1) blockDim.x == warpSize
// 2) Tensors are contiguous // 2) Tensors are contiguous
// //
for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared; SharedMemory<U> shared;
U* buf = shared.getPointer(); U *buf = shared.getPointer();
U mu,sigma2; U mu, sigma2;
cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf); cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf);
const T* lvals = vals + i1*n2; const T *lvals = vals + i1 * n2;
T* ovals = output_vals + i1*n2; T *ovals = output_vals + i1 * n2;
U c_invvar = rsqrt(sigma2 + epsilon); U c_invvar = rsqrt(sigma2 + epsilon);
const int numx = blockDim.x * blockDim.y; const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x; const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL && beta != NULL) { if (gamma != NULL && beta != NULL) {
for (int i = thrx; i < n2; i+=numx) { for (int i = thrx; i < n2; i += numx) {
U curr = static_cast<U>(lvals[i]); U curr = static_cast<U>(lvals[i]);
ovals[i] = gamma[i] * static_cast<T>(c_invvar * (curr - mu)) + beta[i]; ovals[i] = gamma[i] * static_cast<T>(c_invvar * (curr - mu)) + beta[i];
} }
} else { } else {
for (int i = thrx; i < n2; i+=numx) { for (int i = thrx; i < n2; i += numx) {
U curr = static_cast<U>(lvals[i]); U curr = static_cast<U>(lvals[i]);
ovals[i] = static_cast<T>(c_invvar * (curr - mu)); ovals[i] = static_cast<T>(c_invvar * (curr - mu));
} }
@@ -329,254 +294,230 @@ void cuApplyLayerNorm(
} }
} }
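In formula form, cuApplyLayerNorm normalizes each row \(i\) of length \(n_2\) as

\[
\mu_i = \frac{1}{n_2}\sum_j x_{ij},
\qquad
\sigma_i^2 = \frac{1}{n_2}\sum_j (x_{ij}-\mu_i)^2,
\qquad
y_{ij} = \gamma_j\,\frac{x_{ij}-\mu_i}{\sqrt{\sigma_i^2+\epsilon}} + \beta_j,
\]

writing out the per-row mean and invvar = 1/sqrt(sigma2 + epsilon) so the backward kernels below can reuse them.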
template<typename T, typename U> __device__ template <typename T, typename U>
void cuLoadWriteStridedInputs( __device__ void cuLoadWriteStridedInputs(
const int i1_block, const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
const int thr_load_row_off, const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
const int thr_load_col_off, const T *input, const T *dout, const int i1_end, const int n2,
const int i2_off, const U *__restrict__ mean, const U *__restrict__ invvar) {
const int row_stride, int i1 = i1_block + thr_load_row_off;
U* warp_buf1,
U* warp_buf2,
const T* input,
const T* dout,
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
)
{
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) { if (i1 < i1_end) {
U curr_mean = mean[i1]; U curr_mean = mean[i1];
U curr_invvar = invvar[i1]; U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) { for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k; int i2 = i2_off + k;
int load_idx = i1*n2+i2; int load_idx = i1 * n2 + i2;
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2<n2) { if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]); U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]); U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] = curr_dout; warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; warp_buf2[write_idx] =
curr_dout * (curr_input - curr_mean) * curr_invvar;
} else { } else {
warp_buf1[write_idx] = U(0); warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0); warp_buf2[write_idx] = U(0);
} }
} }
} else { } else {
for (int k = 0; k < blockDim.y; ++k) { for (int k = 0; k < blockDim.y; ++k) {
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
warp_buf1[write_idx] = U(0); warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0); warp_buf2[write_idx] = U(0);
} }
} }
} }
template<typename T, typename U> __device__ template <typename T, typename U>
void cuLoadAddStridedInputs( __device__ void cuLoadAddStridedInputs(
const int i1_block, const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
const int thr_load_row_off, const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
const int thr_load_col_off, const T *input, const T *dout, const int i1_end, const int n2,
const int i2_off, const U *__restrict__ mean, const U *__restrict__ invvar) {
const int row_stride, int i1 = i1_block + thr_load_row_off;
U* warp_buf1,
U* warp_buf2,
const T* input,
const T* dout,
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
)
{
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) { if (i1 < i1_end) {
U curr_mean = mean[i1]; U curr_mean = mean[i1];
U curr_invvar = invvar[i1]; U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) { for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k; int i2 = i2_off + k;
int load_idx = i1*n2+i2; int load_idx = i1 * n2 + i2;
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2<n2) { if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]); U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]); U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout; warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; warp_buf2[write_idx] +=
curr_dout * (curr_input - curr_mean) * curr_invvar;
} }
} }
} }
} }
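These two helpers accumulate, tile by tile, exactly the two quantities whose column sums form the layer-norm weight gradients:

\[
\frac{\partial L}{\partial \beta_j} = \sum_i \frac{\partial L}{\partial y_{ij}},
\qquad
\frac{\partial L}{\partial \gamma_j} = \sum_i \frac{\partial L}{\partial y_{ij}}\,\hat{x}_{ij},
\qquad
\hat{x}_{ij} = (x_{ij}-\mu_i)\cdot\mathrm{invvar}_i,
\]

with warp_buf1 collecting the dout terms (beta) and warp_buf2 the dout * (input - mean) * invvar terms (gamma); the kernels below reduce these buffers over the row index.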
template<typename T, typename U> __global__ template <typename T, typename U>
void cuComputePartGradGammaBeta( __global__ void cuComputePartGradGammaBeta(
const T* __restrict__ dout, const T *__restrict__ dout, const T *__restrict__ input, const int n1,
const T* __restrict__ input, const int n2, const U *__restrict__ mean, const U *__restrict__ invvar,
const int n1, U epsilon, U *part_grad_gamma, U *part_grad_beta) {
const int n2, const int numsegs_n1 =
const U* __restrict__ mean, (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
const U* __restrict__ invvar, const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
U epsilon, const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
U* part_grad_gamma, const int i1_beg_plus_one =
U* part_grad_beta) (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
{ const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y); const int row_stride = blockDim.x + 1;
  const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
  const int thr_load_row_off =
      (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
  const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
  SharedMemory<U> shared;
  U *buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y *
                                // blockDim.y + (blockDim.y -
                                // 1)*(blockDim.x/blockDim.y) elements
  U *warp_buf1 = (U *)buf;
  U *warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
  // compute partial sums from strided inputs
  // do this to increase number of loads in flight
  cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off,
                           row_stride, warp_buf1, warp_buf2, input, dout,
                           i1_end, n2, mean, invvar);
  for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
       i1_block += blockDim.y * blockDim.y) {
    cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off,
                           row_stride, warp_buf1, warp_buf2, input, dout,
                           i1_end, n2, mean, invvar);
  }
  __syncthreads();
  // inter-warp reductions
  // sum within each warp
  U acc1 = U(0);
  U acc2 = U(0);
  for (int k = 0; k < blockDim.y; ++k) {
    int row1 = threadIdx.y + k * blockDim.y;
    int idx1 = row1 * row_stride + threadIdx.x;
    acc1 += warp_buf1[idx1];
    acc2 += warp_buf2[idx1];
  }
  warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
  warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
  __syncthreads();
  // sum all warps
  for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
    if (threadIdx.y < offset) {
      int row1 = threadIdx.y;
      int row2 = threadIdx.y + offset;
      int idx1 = row1 * row_stride + threadIdx.x;
      int idx2 = row2 * row_stride + threadIdx.x;
      warp_buf1[idx1] += warp_buf1[idx2];
      warp_buf2[idx1] += warp_buf2[idx2];
    }
    __syncthreads();
  }
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.y == 0 && i2 < n2) {
    int row1 = threadIdx.y;
    int row2 = threadIdx.y + 1;
    int idx1 = row1 * row_stride + threadIdx.x;
    int idx2 = row2 * row_stride + threadIdx.x;
    part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
    part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
  }
}
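// Note: each blockIdx.y writes one row of the {part_size, n2} buffers
// part_grad_gamma / part_grad_beta, i.e. the column-wise partial sums for its
// chunk of input rows; cuComputeGradGammaBeta below folds those part_size rows
// into the final grad_gamma / grad_beta.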
template <typename T, typename U>
__global__ void
cuComputeGradGammaBeta(const U *part_grad_gamma, const U *part_grad_beta,
                       const int part_size, const int n1, const int n2,
                       T *grad_gamma, T *grad_beta) {
  // sum partial gradients for gamma and beta
  SharedMemory<U> shared;
  U *buf = shared.getPointer();
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (i2 < n2) {
    // each warp does sequential reductions until reduced part_size is
    // num_warps
    int num_warp_reductions = part_size / blockDim.y;
    U sum_gamma = U(0);
    U sum_beta = U(0);
    const U *part_grad_gamma_ptr =
        part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
    const U *part_grad_beta_ptr =
        part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
    for (int warp_offset = 0; warp_offset < num_warp_reductions;
         ++warp_offset) {
      sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
      sum_beta += part_grad_beta_ptr[warp_offset * n2];
    }
    // inter-warp reductions
    const int nbsize3 = blockDim.x * blockDim.y / 2;
    for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
      // top half write to shared memory
      if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
        const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
        buf[write_idx] = sum_gamma;
        buf[write_idx + nbsize3] = sum_beta;
      }
      __syncthreads();
      // bottom half sums
      if (threadIdx.y < offset) {
        const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
        sum_gamma += buf[read_idx];
        sum_beta += buf[read_idx + nbsize3];
      }
      __syncthreads();
    }
    // write out fully summed gradients
    if (threadIdx.y == 0) {
      grad_gamma[i2] = sum_gamma;
      grad_beta[i2] = sum_beta;
    }
  }
}
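// The tree reduction above keeps sum_gamma in buf[0, nbsize3) and sum_beta in
// buf[nbsize3, 2 * nbsize3), so the kernel needs blockDim.x * blockDim.y
// elements of dynamic shared memory (matching nshared3 in
// HostLayerNormGradient below).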
template <typename T, typename U>
__global__ void
cuComputeGradInput(const T *__restrict__ dout, const T *__restrict__ dout_resid,
                   const T *__restrict__ input, const int n1, const int n2,
                   const U *__restrict__ mean, const U *__restrict__ invvar,
                   U epsilon, const T *gamma, T *grad_input) {
  for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
    const U c_invvar = invvar[i1];
    const T *k_input = input + i1 * n2;
    const T *k_dout = dout + i1 * n2;
    const T *k_dout_resid = dout_resid + i1 * n2;
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL) {
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss * static_cast<U>(gamma[l + k]);
          sum_loss2 +=
              c_loss * static_cast<U>(gamma[l + k]) * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss * static_cast<U>(gamma[l]);
        sum_loss2 +=
            c_loss * static_cast<U>(gamma[l]) * (c_h - c_mean) * c_invvar;
      }
    } else {
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss;
...@@ -591,161 +532,121 @@ void cuComputeGradInput(
    // inter-warp reductions
    if (blockDim.y > 1) {
      SharedMemory<U> shared;
      U *buf = shared.getPointer();
      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          buf[2 * wrt_i] = sum_loss1;
          buf[2 * wrt_i + 1] = sum_loss2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
          sum_loss1 += buf[2 * read_i];
          sum_loss2 += buf[2 * read_i + 1];
        }
        __syncthreads();
      }
      if (threadIdx.y == 0) {
        buf[2 * threadIdx.x] = sum_loss1;
        buf[2 * threadIdx.x + 1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y != 0) {
        sum_loss1 = buf[2 * threadIdx.x];
        sum_loss2 = buf[2 * threadIdx.x + 1];
      }
    }
    // all threads now have the two sums over l
    U fH = (U)n2;
    U term1 = (U(1) / fH) * c_invvar;
    T *k_grad_input = grad_input + i1 * n2;
    if (gamma != NULL) {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        const T c_resid = static_cast<T>(k_dout_resid[l]);
        U f_grad_input = fH * c_loss * static_cast<U>(gamma[l]);
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input) + c_resid;
      }
    } else {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        const T c_resid = static_cast<T>(k_dout_resid[l]);
        U f_grad_input = fH * c_loss;
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input) + c_resid;
      }
    }
  }
}
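// For reference, the per-element update above computes the LayerNorm input
// gradient
//   dx_j = (1 / H) * invvar * (H * dy_j * g_j - sum_loss1
//                              - (x_j - mean) * invvar * sum_loss2)
// with H = n2, g_j = gamma[j] (or 1 when gamma is NULL),
// sum_loss1 = sum_k dy_k * g_k and
// sum_loss2 = sum_k dy_k * g_k * (x_k - mean) * invvar,
// and then adds the residual gradient dout_resid element-wise.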
template <typename T, typename U>
void HostApplyLayerNorm(T *output, U *mean, U *invvar, const T *input, int n1,
                        int n2, double epsilon, const T *gamma, const T *beta) {
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const dim3 threads(32, 4, 1);
  const uint64_t maxGridY =
      at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
  int nshared =
      threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0;
  cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
      output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);
}
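// Note: the forward kernel is launched with a fixed 32x4 block and a grid of
// min(n1, maxGridSize[1]) rows; dynamic shared memory is only allocated when
// threads.y > 1, i.e. when per-warp row statistics have to be combined across
// warps inside cuApplyLayerNorm.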
template <typename T, typename U>
void HostLayerNormGradient(const T *dout, const T *dout_resid, const U *mean,
                           const U *invvar, const at::Tensor &input, int n1,
                           int n2, const T *gamma, const T *beta,
                           double epsilon, T *grad_input, T *grad_gamma,
                           T *grad_beta) {
  auto stream = at::cuda::getCurrentCUDAStream().stream();

  if (gamma != NULL && beta != NULL) {
    // compute grad_gamma(j) and grad_beta(j)
    const int part_size = 16;
    const dim3 threads2(32, 4, 1);
    const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);
    const int nshared2_a =
        2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
    const int nshared2_b = threads2.x * threads2.y * sizeof(U);
    const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
    at::Tensor part_grad_gamma = at::empty(
        {part_size, n2},
        input.options().dtype(input.scalar_type() == at::ScalarType::Half
                                  ? at::ScalarType::Float
                                  : input.scalar_type()));
    at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
    cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
        dout, static_cast<T *>(input.data_ptr()), n1, n2, mean, invvar,
        U(epsilon), static_cast<U *>(part_grad_gamma.data_ptr()),
        static_cast<U *>(part_grad_beta.data_ptr()));

    const dim3 threads3(32, 8, 1);
    const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);
    const int nshared3 = threads3.x * threads3.y * sizeof(U);
    cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
        static_cast<U *>(part_grad_gamma.data_ptr()),
        static_cast<U *>(part_grad_beta.data_ptr()), part_size, n1, n2,
        grad_gamma, grad_beta);
  }

  // compute grad_input
  const uint64_t maxGridY =
      at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
  const dim3 threads1(32, 4, 1);
  int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
  cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
      dout, dout_resid, static_cast<T *>(input.data_ptr()), n1, n2, mean,
      invvar, U(epsilon), gamma, grad_input);
}
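// Note: the gamma/beta gradients are reduced in two passes. The first kernel
// writes part_size (= 16) rows of partial sums, accumulated in float even for
// half inputs (see the dtype selection above), presumably to limit
// accumulation error; the second kernel collapses them into grad_gamma and
// grad_beta.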
...@@ -5,81 +5,66 @@ namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {

std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const uint8_t *pad_mask,
                                    float dropout_prob);

torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask,
                       const uint8_t *padding_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                    \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                         \
  CHECK_CUDA(x);                                                               \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
                               torch::Tensor const &input,
                               torch::Tensor const &pad_mask,
                               float dropout_prob) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(is_training, heads, input,
                  use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
                  torch::Tensor const &softmax_results,
                  torch::Tensor const &dropout_mask,
                  torch::Tensor const &padding_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  // AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
  //            "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
                  use_mask
                      ? static_cast<const uint8_t *>(padding_mask.data_ptr())
                      : nullptr,
                  dropout_prob);
}

} // end namespace mask_softmax_dropout
...@@ -87,7 +72,8 @@ torch::Tensor bwd(
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd,
        "Self Multihead Attention masked softmax dropout -- Forward.");
  m.def("backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd,
        "Self Multihead Attention masked softmax dropout -- Backward.");
}
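// Shape convention assumed by the CUDA side (see fwd_cuda below): input and
// output_grads are [heads * sequences, q_seq_len, k_seq_len] fp16 tensors, and
// the optional byte pad_mask holds one row of length k_seq_len per sequence,
// reused across that sequence's heads and query positions.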
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "softmax.h"

namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {

std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const uint8_t *pad_mask,
                                    float dropout_prob) {
  const int attn_batches = input.size(0);
  const int sequences = attn_batches / heads;
  const int q_seq_len = input.size(1);
  const int k_seq_len = q_seq_len;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = input.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *input_ptr = static_cast<void *>(input.data_ptr());
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
        attn_batches * q_seq_len);
  } else {
    softmax_success = dispatch_masked_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
        k_seq_len, attn_batches * q_seq_len,
        attn_batches * q_seq_len / sequences);
  }

  if (is_training) {
    // use at:: function so that C++ version generates the same random mask as
    // python version
    auto dropout_tuple =
        at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
    dropout_results = std::get<0>(dropout_tuple);
    dropout_mask = std::get<1>(dropout_tuple);
  }

  // Matmul2

  return {dropout_results, dropout_mask, softmax_results};
}
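// Note: at::_fused_dropout takes the keep probability (1 - dropout_prob) and
// returns the already-rescaled activations together with a uint8 keep-mask;
// bwd_cuda below reapplies that mask with the matching
// 1.0 / (1.0 - dropout_prob) scale.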
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask,
                       const uint8_t *padding_mask, float dropout_prob) {
  const int attn_batches = output_grads.size(0);
  const int q_seq_len = output_grads.size(1);
  const int k_seq_len = q_seq_len;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  // torch::Tensor input_grads = torch::empty_like(output_grads);

  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
  if (padding_mask == nullptr) {
    dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
        static_cast<half *>(output_grads.data_ptr()),
        static_cast<half *>(output_grads.data_ptr()),
        reinterpret_cast<half const *>(softmax_results.data_ptr()),
        static_cast<uint8_t const *>(dropout_mask.data_ptr()),
        1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
        attn_batches * q_seq_len, stream);
  } else {
    dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float,
                                                             false>(
        static_cast<half *>(output_grads.data_ptr()),
        static_cast<half *>(output_grads.data_ptr()),
        reinterpret_cast<half const *>(softmax_results.data_ptr()),
        static_cast<uint8_t const *>(dropout_mask.data_ptr()),
        static_cast<uint8_t const *>(padding_mask), 1.0 / (1.0 - dropout_prob),
        k_seq_len, k_seq_len, attn_batches * q_seq_len, heads, stream);
  }
  // backward pass is completely in-place
  return output_grads;
}
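// Because the softmax/dropout backward overwrites output_grads in place, the
// tensor returned here aliases the incoming gradient; callers must not expect
// the original output_grads values to survive the call.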
} // namespace mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
#pragma once
// Philox CUDA.

class Philox {
public:
...@@ -15,28 +15,30 @@ public:
    incr_n(offset / 4);
  }
  __device__ inline uint4 operator()() {
    if (STATE == 0) {
      uint4 counter_ = counter;
      uint2 key_ = key;
      // 7-round philox
      for (int i = 0; i < 6; i++) {
        counter_ = single_round(counter_, key_);
        key_.x += (kPhilox10A);
        key_.y += (kPhilox10B);
      }
      output = single_round(counter_, key_);
      incr();
    }
    // return a float4 directly
    // unsigned long ret;
    // switch(STATE) {
    //  case 0: ret = output.x; break;
    //  case 1: ret = output.y; break;
    //  case 2: ret = output.z; break;
    //  case 3: ret = output.w; break;
    //}
    // STATE = (STATE + 1) % 4;
    return output;
  }

private:
  uint4 counter;
  uint4 output;
...@@ -67,7 +69,7 @@ private:
  __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
                                    unsigned int *result_high) {
    *result_high = __umulhi(a, b);
    return a * b;
  }
  __device__ inline uint4 single_round(uint4 ctr, uint2 key) {
    unsigned int hi0;
...@@ -84,7 +86,7 @@ private:
};
// Inverse of 2^32.
#define M_RAN_INVM32 2.3283064e-10f
__device__ __inline__ float4 uniform4(uint4 x) {
  return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,
                     x.w * M_RAN_INVM32);
}
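// Note: M_RAN_INVM32 is 2^-32, so uniform4 maps each 32-bit lane of a Philox
// draw to a float in [0, 1); one call therefore yields four uniform samples.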
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace self_bias_additive_mask {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    torch::Tensor const &input_biases,
                                    torch::Tensor const &output_biases,
                                    const half *pad_mask, float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    // torch::Tensor const& softmax_results,
    torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    // torch::Tensor const& input_biases,
    // torch::Tensor const& output_biases,
    torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                    \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                         \
  CHECK_CUDA(x);                                                               \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &input_weights,
    torch::Tensor const &output_weights, torch::Tensor const &input_biases,
    torch::Tensor const &output_biases, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(use_mask, "no mask is not supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
               "Only Half is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights,
                  output_weights, input_biases, output_biases,
                  use_mask ? static_cast<const half *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  bmm1_results, pad_mask, input_lin_results, inputs,
                  input_weights, output_weights, dropout_mask, dropout_prob);
}

} // end namespace rocblas_gemmex
...@@ -140,4 +112,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
  m.def("backward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
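// Unlike the byte padding mask used by the self_bias variant further below,
// this extension takes an additive fp16 mask that is added to the attention
// scores before the softmax (typically 0 for valid positions and a large
// negative value for masked ones), which is why use_mask is mandatory here.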
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "layer_norm.h"
#include "softmax.h"
#include "strided_batched_gemm.h"

namespace multihead_attn {
namespace self_bias_additive_mask {
...@@ -55,28 +52,36 @@ std::vector<torch::Tensor> fwd_cuda(
  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = inputs.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor input_lin_results =
      torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor bmm1_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor outputs = torch::empty_like(inputs, act_options);

  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
  void *k_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *bmm1_results_ptr = static_cast<void *>(bmm1_results.data_ptr());
  void *dropout_results_ptr = static_cast<void *>(dropout_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
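  // Note on layout: input_lin_results is [q_seq_len, sequences, 3 * embed_dim]
  // with the Q, K and V projections interleaved per head (batch stride of
  // 3 * head_dim), which is why the K and V pointers above are simply the base
  // pointer offset by head_dim and 2 * head_dim half elements.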
...@@ -136,27 +141,24 @@ std::vector<torch::Tensor> fwd_cuda(
  // Padded Softmax
  bool softmax_success = false;
  if (is_training) {
    softmax_success =
        dispatch_additive_masked_softmax_dropout<half, half, float>(
            reinterpret_cast<half *>(dropout_results_ptr),
            (is_training)
                ? reinterpret_cast<uint8_t *>(dropout_mask.data_ptr<uint8_t>())
                : nullptr,
            reinterpret_cast<const half *>(bmm1_results_ptr), pad_mask,
            attn_batches * q_seq_len * q_seq_len, k_seq_len, k_seq_len,
            attn_batches * q_seq_len, attn_batches * q_seq_len / sequences,
            1.0f - dropout_prob, stream);
  } else {
    softmax_success = dispatch_additive_masked_softmax<half, half, float>(
        reinterpret_cast<half *>(
            dropout_results_ptr), // this is actually softmax results, but
                                  // making it consistent for the next function
        reinterpret_cast<const half *>(bmm1_results_ptr), pad_mask, k_seq_len,
        k_seq_len, attn_batches * q_seq_len,
        attn_batches * q_seq_len / sequences);
  }

  // Matmul2
...@@ -211,73 +213,63 @@ std::vector<torch::Tensor> fwd_cuda(
      flags));

  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_lin_results, bmm1_results, dropout_results,
          dropout_mask, matmul2_results, outputs};
}
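// bwd_cuda takes input_lin_results, bmm1_results, dropout_results,
// dropout_mask and matmul2_results back as arguments, so these intermediates
// are returned alongside the attention output instead of being recomputed in
// the backward pass.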
std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_grads = torch::empty_like(inputs);
  torch::Tensor input_weight_grads = torch::empty_like(input_weights);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);
  // Intermediate Tensor Allocations
  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads = torch::empty_like(dropout_results);
  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);

  auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
  auto k_lin_results_ptr =
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr =
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr =
      static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr =
      static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};

  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
...@@ -496,13 +488,8 @@ std::vector<torch::Tensor> bwd_cuda(
  auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_grads, input_weight_grads, output_weight_grads,
          input_bias_grads, output_bias_grads};
}
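// The returned gradients cover the inputs, both projection weight matrices and
// both bias vectors. input_bias_grads above is obtained by flattening the
// input-projection gradient to 2D and summing over all tokens;
// output_bias_grads, computed in the elided hunk above, presumably follows the
// same pattern.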

} // end namespace rocblas_gemmex
...@@ -5,127 +5,102 @@ namespace multihead_attn {
namespace self_bias {
namespace rocblas_gemmex {

std::vector<torch::Tensor>
fwd_cuda(bool use_time_mask, bool is_training, int heads,
         torch::Tensor const &inputs, torch::Tensor const &input_weights,
         torch::Tensor const &output_weights, torch::Tensor const &input_biases,
         torch::Tensor const &output_biases, const uint8_t *pad_mask,
         float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    // torch::Tensor const& input_biases,
    // torch::Tensor const& output_biases,
    torch::Tensor const &dropout_mask, float dropout_prob);

// C++ interface

#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                    \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                         \
  CHECK_CUDA(x);                                                               \
  CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs, torch::Tensor const &input_weights,
    torch::Tensor const &output_weights, torch::Tensor const &input_biases,
    torch::Tensor const &output_biases, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights,
                  output_weights, input_biases, output_biases,
                  use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported"); AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda( return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
heads, softmax_results, input_lin_results, inputs, input_weights,
output_grads, output_weights, dropout_mask, dropout_prob);
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob
);
} }
} // end namespace rocblas_gemmex } // end namespace rocblas_gemmex
@@ -136,4 +111,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward."); m.def("forward", &multihead_attn::self_bias::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward."); m.def("backward", &multihead_attn::self_bias::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
} }
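Both wrappers above pass the optional padding mask down as a raw pointer, with nullptr standing in for "no mask" so the kernels can skip the masked softmax path. A minimal sketch of that convention, assuming the BYTE mask checked above (the helper name is illustrative):

#include <torch/extension.h>

// Return a device pointer to the mask, or nullptr when no mask is requested.
const uint8_t *optional_byte_mask(bool use_mask, torch::Tensor const &pad_mask) {
  if (!use_mask) {
    return nullptr;
  }
  TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte,
              "pad_mask must be a BYTE tensor");
  return static_cast<const uint8_t *>(pad_mask.data_ptr());
}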
#include <vector>
#include <math.h>
#include <iostream> #include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
//#include <cuda_profiler_api.h> //#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h" #include "dropout.h"
#include "layer_norm.h" #include "layer_norm.h"
#include "softmax.h"
// symbol to be automatically resolved by PyTorch libs #include "strided_batched_gemm.h"
extern THCState *state;
namespace multihead_attn { namespace multihead_attn {
namespace self_bias { namespace self_bias {
namespace rocblas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor>
bool use_time_mask, fwd_cuda(bool use_time_mask, bool is_training, int heads,
bool is_training, torch::Tensor const &inputs, torch::Tensor const &input_weights,
int heads, torch::Tensor const &output_weights, torch::Tensor const &input_biases,
torch::Tensor const& inputs, torch::Tensor const &output_biases, const uint8_t *pad_mask,
torch::Tensor const& input_weights, float dropout_prob) {
torch::Tensor const& output_weights, const int embed_dim = inputs.size(2);
torch::Tensor const& input_biases, const int sequences = inputs.size(1);
torch::Tensor const& output_biases, const int q_seq_len = inputs.size(0);
const uint8_t* pad_mask, const int k_seq_len = q_seq_len;
float dropout_prob const int batches = sequences * q_seq_len;
) const int head_dim = embed_dim / heads;
{ const int output_lin_dim = 3 * embed_dim;
const int embed_dim = inputs.size(2); const int attn_batches = heads * sequences;
const int sequences = inputs.size(1); const int lead_dim = attn_batches * 3 * head_dim;
const int q_seq_len = inputs.size(0); const int batch_stride = 3 * head_dim;
const int k_seq_len = q_seq_len; const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const int batches = sequences * q_seq_len; const float alpha = 1.0;
const int head_dim = embed_dim / heads; const float beta_zero = 0.0;
const int output_lin_dim = 3 * embed_dim; const float beta_one = 1.0;
const int attn_batches = heads * sequences; const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim; // There is no reason to use more than one stream as every kernel is
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent // sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream); cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code) // 3 Intermediate Results + Output (Note: dropout intermediates are generated
auto act_options = inputs.options().requires_grad(false); // by ATen library code)
auto act_options = inputs.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8); auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); torch::Tensor input_lin_results =
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); torch::Tensor softmax_results =
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); torch::Tensor dropout_results =
torch::Tensor outputs = torch::empty_like(inputs, act_options); torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options);
// Input Linear Results Pointers to Q, K, and V of interleaved activations // Input Linear Results Pointers to Q, K, and V of interleaved activations
void* q_lin_results_ptr = static_cast<void*>(input_lin_results.data_ptr()); void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
void* k_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim); void *k_lin_results_ptr = static_cast<void *>(
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim); static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr()); void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'}; char a_layout_t{'t'};
char a_layout_n{'n'}; char a_layout_n{'n'};
@@ -136,37 +134,29 @@ std::vector<torch::Tensor> fwd_cuda(
bool softmax_success = false; bool softmax_success = false;
if (pad_mask == nullptr) { if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>( softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr), reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
k_seq_len, k_seq_len, attn_batches * q_seq_len);
k_seq_len,
attn_batches*q_seq_len);
} else { } else {
if (use_time_mask) { if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>( softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr), reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
q_seq_len);
} else { } else {
softmax_success = dispatch_masked_softmax<half, half, float>( softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr), reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len,
k_seq_len, attn_batches * q_seq_len / sequences);
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
} }
} }
if (is_training) { if (is_training) {
//use at:: function so that C++ version generates the same random mask as python version // use at:: function so that C++ version generates the same random mask as
auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob); // python version
auto dropout_tuple =
at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
dropout_results = std::get<0>(dropout_tuple); dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple); dropout_mask = std::get<1>(dropout_tuple);
} }
@@ -223,72 +213,63 @@ std::vector<torch::Tensor> fwd_cuda(
flags)); flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {input_lin_results, softmax_results, dropout_results,
input_lin_results, dropout_mask, matmul2_results, outputs};
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs
};
} }
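The training branch above leans on `at::_fused_dropout`, which takes the keep probability `1 - dropout_prob` and returns the rescaled activations together with the byte mask that the backward pass reuses; using the ATen op keeps the extension's mask identical to what the Python path would generate. A standalone sketch of that call with illustrative shapes:

#include <ATen/ATen.h>
#include <tuple>

void fused_dropout_example() {
  // Illustrative stand-in for the softmax output: an fp16 CUDA tensor.
  auto softmax_results =
      at::rand({64, 32, 32}, at::dtype(at::kHalf).device(at::kCUDA));
  float dropout_prob = 0.1f;
  // Returns {rescaled output, uint8 keep mask} in a single tuple.
  auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
  at::Tensor dropout_results = std::get<0>(dropout_tuple);
  at::Tensor dropout_mask = std::get<1>(dropout_tuple);
  (void)dropout_results;
  (void)dropout_mask;
}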
std::vector<torch::Tensor> bwd_cuda( std::vector<torch::Tensor> bwd_cuda(
int heads, int heads, torch::Tensor const &output_grads,
torch::Tensor const& output_grads, torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const& matmul2_results, torch::Tensor const &softmax_results,
torch::Tensor const& dropout_results, torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const& softmax_results, torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const& input_lin_results, torch::Tensor const &dropout_mask, float dropout_prob) {
torch::Tensor const& inputs, const int embed_dim = inputs.size(2);
torch::Tensor const& input_weights, const int sequences = inputs.size(1);
torch::Tensor const& output_weights, const int q_seq_len = inputs.size(0);
torch::Tensor const& dropout_mask, const int k_seq_len = q_seq_len;
float dropout_prob const int batches = sequences * q_seq_len;
) const int head_dim = embed_dim / heads;
{ const int output_lin_dim = 3 * embed_dim;
const int embed_dim = inputs.size(2); const int attn_batches = heads * sequences;
const int sequences = inputs.size(1); const int lead_dim = attn_batches * 3 * head_dim;
const int q_seq_len = inputs.size(0); const int batch_stride = 3 * head_dim;
const int k_seq_len = q_seq_len; const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const int batches = sequences * q_seq_len; const float alpha = 1.0;
const int head_dim = embed_dim / heads; const float beta = 0.0;
const int output_lin_dim = 3 * embed_dim; const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const int attn_batches = heads * sequences;
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// TODO: Streams can be used in Backprop but I haven't added more than one // TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code // in my first attempt to create the code
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream); cublasSetStream(handle, stream);
// Output Tensor Allocations // Output Tensor Allocations
torch::Tensor input_grads = torch::empty_like(inputs); torch::Tensor input_grads = torch::empty_like(inputs);
torch::Tensor input_weight_grads = torch::empty_like(input_weights); torch::Tensor input_weight_grads = torch::empty_like(input_weights);
torch::Tensor output_weight_grads = torch::empty_like(output_weights); torch::Tensor output_weight_grads = torch::empty_like(output_weights);
// Intermediate Tensor Allocations // Intermediate Tensor Allocations
at::Tensor output_lin_grads = torch::empty_like(matmul2_results); at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
at::Tensor matmul2_grads = torch::empty_like(dropout_results); at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results); at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()); auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim; auto k_lin_results_ptr =
auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim; static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;
auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()); auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim; auto k_lin_grads_ptr =
auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim; static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;
char a_layout_n{'n'}; char a_layout_n{'n'};
char a_layout_t{'t'}; char a_layout_t{'t'};
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
@@ -393,15 +374,13 @@ std::vector<torch::Tensor> bwd_cuda(
// Apply Dropout Mask and Scale by Dropout Probability // Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad // Softmax Grad
dispatch_masked_scale_softmax_backward_stream<half, half, float,false>( dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()), reinterpret_cast<half const *>(softmax_results.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()), static_cast<uint8_t const *>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob), 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
k_seq_len, attn_batches * q_seq_len, stream);
k_seq_len,
attn_batches*q_seq_len, stream);
// Matmul1 Dgrad1 // Matmul1 Dgrad1
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
@@ -503,13 +482,8 @@ std::vector<torch::Tensor> bwd_cuda(
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {input_grads, input_weight_grads, output_weight_grads,
input_grads, input_bias_grads, output_bias_grads};
input_weight_grads,
output_weight_grads,
input_bias_grads,
output_bias_grads
};
} }
} // end namespace rocblas_gemmex } // end namespace rocblas_gemmex
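The file above shows the pattern this commit applies throughout: instead of resolving an `extern THCState *state` symbol at load time, the cuBLAS handle and CUDA stream come from ATen's per-device state and are bound together before any GEMM launches, which keeps the GEMMs ordered with the surrounding ATen kernels. A minimal sketch of that replacement (the function name is illustrative):

#include <ATen/cuda/CUDAContext.h>
#include <cublas_v2.h>

// Fetch the handle and stream PyTorch already manages and tie them together.
cublasHandle_t acquire_bound_handle() {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);
  return handle;
}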
@@ -5,120 +5,98 @@ namespace multihead_attn {
namespace self { namespace self {
namespace rocblas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
bool use_time_mask, int heads, torch::Tensor const &inputs,
bool is_training, torch::Tensor const &input_weights,
int heads, torch::Tensor const &output_weights,
torch::Tensor const& inputs, const uint8_t *pad_mask,
torch::Tensor const& input_weights, float dropout_prob);
torch::Tensor const& output_weights,
const uint8_t* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> bwd_cuda( std::vector<torch::Tensor> bwd_cuda(
int heads, int heads, torch::Tensor const &output_grads,
torch::Tensor const& output_grads, torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const& matmul2_results, torch::Tensor const &softmax_results,
torch::Tensor const& dropout_results, torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const& softmax_results, torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const& input_lin_results, torch::Tensor const &dropout_mask, float dropout_prob);
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
);
// C++ interface // C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CUDA(x) \
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) #define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd( std::vector<torch::Tensor>
bool use_mask, fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
bool use_time_mask, torch::Tensor const &inputs, torch::Tensor const &input_weights,
bool is_training, torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
int heads, float dropout_prob) {
torch::Tensor const& inputs, torch::Tensor const& input_weights, AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
torch::Tensor const& output_weights, AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
torch::Tensor const& pad_mask,
float dropout_prob
)
{
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor"); AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) { if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor"); AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported"); AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
} }
return fwd_cuda( return fwd_cuda(
use_time_mask, use_time_mask, is_training, heads, inputs, input_weights, output_weights,
is_training, use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr()) : nullptr,
heads, dropout_prob);
inputs,
input_weights,
output_weights,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
} }
std::vector<torch::Tensor> bwd( std::vector<torch::Tensor>
int heads, bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const& output_grads, torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const& matmul2_results, torch::Tensor const &softmax_results,
torch::Tensor const& dropout_results, torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const& softmax_results, torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const& input_lin_results, torch::Tensor const &dropout_mask, float dropout_prob) {
torch::Tensor const& inputs, AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
torch::Tensor const& input_weights, AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
torch::Tensor const& output_weights, AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
torch::Tensor const& dropout_mask, AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
float dropout_prob
)
{
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor"); AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor"); AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor"); AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor"); AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor"); AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported"); AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
return bwd_cuda( AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
heads, "Only HALF is supported");
output_grads, AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
matmul2_results, "Only HALF is supported");
dropout_results, AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
softmax_results, "Only HALF is supported");
input_lin_results, AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
inputs, "Only BYTE is supported");
input_weights,
output_weights, return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
dropout_mask, softmax_results, input_lin_results, inputs, input_weights,
dropout_prob output_weights, dropout_mask, dropout_prob);
);
} }
} // end namespace rocblas_gemmex } // end namespace rocblas_gemmex
@@ -129,4 +107,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self::rocblas_gemmex::fwd, "Self Multihead Attention Forward."); m.def("forward", &multihead_attn::self::rocblas_gemmex::fwd, "Self Multihead Attention Forward.");
m.def("backward", &multihead_attn::self::rocblas_gemmex::bwd, "Self Multihead Attention Backward."); m.def("backward", &multihead_attn::self::rocblas_gemmex::bwd, "Self Multihead Attention Backward.");
} }
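In the CUDA implementation that follows, the input linear layer writes Q, K and V into a single buffer of shape [q_seq_len, sequences, 3 * embed_dim], so the per-head K and V blocks are reached by pointer offsets of `head_dim` and `2 * head_dim` rather than by separate tensors. A sketch of that addressing (the function name is illustrative):

#include <cuda_fp16.h>
#include <torch/extension.h>

void qkv_pointer_example(torch::Tensor const &input_lin_results, int head_dim) {
  // Per attention batch the 3*head_dim slab is laid out as [Q | K | V], which
  // is why the strided-batched GEMMs use batch_stride = 3 * head_dim.
  half *base = static_cast<half *>(input_lin_results.data_ptr());
  half *q_ptr = base;
  half *k_ptr = base + head_dim;
  half *v_ptr = base + 2 * head_dim;
  (void)q_ptr; (void)k_ptr; (void)v_ptr;
}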
#include <vector>
#include <math.h>
#include <iostream> #include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
//#include <cuda_profiler_api.h> //#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h" #include "dropout.h"
#include "layer_norm.h" #include "layer_norm.h"
#include "softmax.h"
// symbol to be automatically resolved by PyTorch libs #include "strided_batched_gemm.h"
extern THCState *state;
namespace multihead_attn { namespace multihead_attn {
namespace self { namespace self {
namespace rocblas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
bool use_time_mask, int heads, torch::Tensor const &inputs,
bool is_training, torch::Tensor const &input_weights,
int heads, torch::Tensor const &output_weights,
torch::Tensor const& inputs, const uint8_t *pad_mask,
torch::Tensor const& input_weights, float dropout_prob) {
torch::Tensor const& output_weights, const int embed_dim = inputs.size(2);
const uint8_t* pad_mask, const int sequences = inputs.size(1);
float dropout_prob const int q_seq_len = inputs.size(0);
) const int k_seq_len = q_seq_len;
{ const int batches = sequences * q_seq_len;
const int embed_dim = inputs.size(2); const int head_dim = embed_dim / heads;
const int sequences = inputs.size(1); const int output_lin_dim = 3 * embed_dim;
const int q_seq_len = inputs.size(0); const int attn_batches = heads * sequences;
const int k_seq_len = q_seq_len; const int lead_dim = attn_batches * 3 * head_dim;
const int batches = sequences * q_seq_len; const int batch_stride = 3 * head_dim;
const int head_dim = embed_dim / heads; const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const int output_lin_dim = 3 * embed_dim; const float alpha = 1.0;
const int attn_batches = heads * sequences; const float beta = 0.0;
const int lead_dim = attn_batches * 3 * head_dim; const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len; // There is no reason to use more than one stream as every kernel is
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent // sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream); cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code) // 3 Intermediate Results + Output (Note: dropout intermediates are generated
auto act_options = inputs.options().requires_grad(false); // by ATen library code)
auto act_options = inputs.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8); auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); torch::Tensor input_lin_results =
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); torch::Tensor softmax_results =
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); torch::Tensor dropout_results =
torch::Tensor outputs = torch::empty_like(inputs, act_options); torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options);
// Input Linear Results Pointers to Q, K, and V of interleaved activations // Input Linear Results Pointers to Q, K, and V of interleaved activations
void* q_lin_results_ptr = static_cast<void*>(input_lin_results.data_ptr()); void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
void* k_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim); void *k_lin_results_ptr = static_cast<void *>(
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim); static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr()); void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'}; char a_layout_t{'t'};
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
@@ -132,43 +132,33 @@ std::vector<torch::Tensor> fwd_cuda(
bool softmax_success = false; bool softmax_success = false;
if (pad_mask == nullptr) { if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>( softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr), reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
k_seq_len, k_seq_len, attn_batches * q_seq_len);
k_seq_len,
attn_batches*q_seq_len);
} else { } else {
if (use_time_mask) { if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>( softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr), reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
q_seq_len);
} else { } else {
softmax_success = dispatch_masked_softmax<half, half, float>( softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr), reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len,
k_seq_len, attn_batches * q_seq_len / sequences);
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
} }
} }
assert(softmax_success); assert(softmax_success);
if (is_training) { if (is_training) {
apex_fused_dropout_cuda<at::Half,float,uint32_t>( apex_fused_dropout_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const*>(softmax_results.data_ptr()), static_cast<at::Half const *>(softmax_results.data_ptr()),
static_cast<at::Half*>(dropout_results.data_ptr()), static_cast<at::Half *>(dropout_results.data_ptr()),
static_cast<uint8_t*>(dropout_mask.data_ptr()), static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
dropout_elems, (1.0f - dropout_prob));
(1.0f - dropout_prob));
} }
// Matmul2 // Matmul2
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
a_layout_n, a_layout_n,
@@ -219,67 +209,58 @@ std::vector<torch::Tensor> fwd_cuda(
flags)); flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {input_lin_results, softmax_results, dropout_results,
input_lin_results, dropout_mask, matmul2_results, outputs};
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs
};
} }
std::vector<torch::Tensor> bwd_cuda( std::vector<torch::Tensor> bwd_cuda(
int heads, int heads, torch::Tensor const &output_grads,
torch::Tensor const& output_grads, torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const& matmul2_results, torch::Tensor const &softmax_results,
torch::Tensor const& dropout_results, torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const& softmax_results, torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const& input_lin_results, torch::Tensor const &dropout_mask, float dropout_prob) {
torch::Tensor const& inputs, const int embed_dim = inputs.size(2);
torch::Tensor const& input_weights, const int sequences = inputs.size(1);
torch::Tensor const& output_weights, const int q_seq_len = inputs.size(0);
torch::Tensor const& dropout_mask, const int k_seq_len = q_seq_len;
float dropout_prob const int batches = sequences * q_seq_len;
) const int head_dim = embed_dim / heads;
{ const int output_lin_dim = 3 * embed_dim;
const int embed_dim = inputs.size(2); const int attn_batches = heads * sequences;
const int sequences = inputs.size(1); const int lead_dim = attn_batches * 3 * head_dim;
const int q_seq_len = inputs.size(0); const int batch_stride = 3 * head_dim;
const int k_seq_len = q_seq_len; const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const int batches = sequences * q_seq_len; const float alpha = 1.0;
const int head_dim = embed_dim / heads; const float beta = 0.0;
const int output_lin_dim = 3 * embed_dim; const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const int attn_batches = heads * sequences;
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// TODO: Streams can be used in Backprop but I haven't added more than one // TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code // in my first attempt to create the code
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream); cublasSetStream(handle, stream);
// Output Tensor Allocations // Output Tensor Allocations
torch::Tensor input_grads = torch::empty_like(inputs); torch::Tensor input_grads = torch::empty_like(inputs);
torch::Tensor input_weight_grads = torch::empty_like(input_weights); torch::Tensor input_weight_grads = torch::empty_like(input_weights);
torch::Tensor output_weight_grads = torch::empty_like(output_weights); torch::Tensor output_weight_grads = torch::empty_like(output_weights);
// Intermediate Tensor Allocations // Intermediate Tensor Allocations
at::Tensor output_lin_grads = torch::empty_like(matmul2_results); at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
at::Tensor matmul2_grads = torch::empty_like(dropout_results); at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results); at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()); auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim; auto k_lin_results_ptr =
auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim; static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr =
auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()); static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;
auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim; auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;
char a_layout_n{'n'}; char a_layout_n{'n'};
char a_layout_t{'t'}; char a_layout_t{'t'};
@@ -397,12 +378,10 @@ std::vector<torch::Tensor> bwd_cuda(
// Softmax Grad // Softmax Grad
bool softmax_success = false; bool softmax_success = false;
softmax_success = dispatch_softmax_backward<half, half, float>( softmax_success = dispatch_softmax_backward<half, half, float>(
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()), reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
k_seq_len, k_seq_len, attn_batches * q_seq_len);
k_seq_len,
attn_batches*q_seq_len);
assert(softmax_success); assert(softmax_success);
// Matmul1 Dgrad1 // Matmul1 Dgrad1
@@ -5,169 +5,145 @@ namespace multihead_attn {
namespace self_norm_add { namespace self_norm_add {
namespace rocblas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
bool use_time_mask, int heads, torch::Tensor const &inputs,
bool is_training, torch::Tensor const &lyr_nrm_gamma_weights,
int heads, torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const& inputs, torch::Tensor const &input_weights,
torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const &output_weights,
torch::Tensor const& lyr_nrm_beta_weights, const uint8_t *pad_mask,
torch::Tensor const& input_weights, float dropout_prob);
torch::Tensor const& output_weights,
const uint8_t* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> bwd_cuda( std::vector<torch::Tensor> bwd_cuda(
int heads, int heads, torch::Tensor const &output_grads,
torch::Tensor const& output_grads, torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const& matmul2_results, torch::Tensor const &softmax_results,
torch::Tensor const& dropout_results, torch::Tensor const &input_lin_results,
torch::Tensor const& softmax_results, torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const& input_lin_results, torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
torch::Tensor const& lyr_nrm_results, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_mean, torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const& lyr_nrm_invvar, torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const& inputs, torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
torch::Tensor const& lyr_nrm_gamma_weights, float dropout_prob);
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
torch::Tensor const& dropout_add_mask,
float dropout_prob
);
// C++ interface // C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CUDA(x) \
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) #define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd( std::vector<torch::Tensor>
bool use_mask, fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
bool use_time_mask, torch::Tensor const &inputs, torch::Tensor const &lyr_nrm_gamma_weights,
bool is_training, torch::Tensor const &lyr_nrm_beta_weights,
int heads, torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const& inputs, torch::Tensor const &pad_mask, float dropout_prob) {
torch::Tensor const& lyr_nrm_gamma_weights, AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
torch::Tensor const& lyr_nrm_beta_weights, AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
torch::Tensor const& input_weights, AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
torch::Tensor const& output_weights, AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
torch::Tensor const& pad_mask, AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
float dropout_prob
)
{
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) { if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor"); AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported"); AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
} }
return fwd_cuda( return fwd_cuda(
use_time_mask, use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights,
is_training, lyr_nrm_beta_weights, input_weights, output_weights,
heads, use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr()) : nullptr,
inputs, dropout_prob);
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights,
output_weights,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
} }
std::vector<torch::Tensor>
std::vector<torch::Tensor> bwd( bwd(int heads, torch::Tensor const &output_grads,
int heads, torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const& output_grads, torch::Tensor const &softmax_results,
torch::Tensor const& matmul2_results, torch::Tensor const &input_lin_results,
torch::Tensor const& dropout_results, torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const& softmax_results, torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
torch::Tensor const& input_lin_results, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_results, torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const& lyr_nrm_mean, torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const& lyr_nrm_invvar, torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
torch::Tensor const& inputs, float dropout_prob) {
torch::Tensor const& lyr_nrm_gamma_weights, AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
torch::Tensor const& lyr_nrm_beta_weights, AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
torch::Tensor const& input_weights, AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
torch::Tensor const& output_weights, AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
torch::Tensor const& dropout_mask, AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
torch::Tensor const& dropout_add_mask, AT_ASSERTM(lyr_nrm_results.dim() == 3, "expected 3D tensor");
float dropout_prob AT_ASSERTM(lyr_nrm_mean.dim() == 1, "expected 1D tensor");
) AT_ASSERTM(lyr_nrm_invvar.dim() == 1, "expected 1D tensor");
{ AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_mean.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_invvar.dim() == 1, "expected 1D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor"); AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor"); AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor"); AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor"); AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor"); AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor"); AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported"); AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported"); "Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half,
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported"); AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float,
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported"); "Only FLOAT is supported");
AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported"); AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float,
"Only FLOAT is supported");
return bwd_cuda(heads, AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
output_grads, "Only HALF is supported");
matmul2_results, AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
dropout_results, "Only HALF is supported");
softmax_results, AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
input_lin_results, "Only HALF is supported");
lyr_nrm_results, AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
lyr_nrm_mean, "Only HALF is supported");
lyr_nrm_invvar, AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
inputs, "Only HALF is supported");
lyr_nrm_gamma_weights, AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
lyr_nrm_beta_weights, "Only BYTE is supported");
input_weights, AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte,
output_weights, "Only BYTE is supported");
dropout_mask,
dropout_add_mask, return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
dropout_prob softmax_results, input_lin_results, lyr_nrm_results,
); lyr_nrm_mean, lyr_nrm_invvar, inputs, lyr_nrm_gamma_weights,
lyr_nrm_beta_weights, input_weights, output_weights,
dropout_mask, dropout_add_mask, dropout_prob);
} }
} // end namespace rocblas_gemmex } // end namespace rocblas_gemmex
} // end namespace self_norm_add } // end namespace self_norm_add
} // end namespace multihead_attn } // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_norm_add::rocblas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward."); m.def("forward", &multihead_attn::self_norm_add::rocblas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def("backward", &multihead_attn::self_norm_add::rocblas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward."); m.def("backward", &multihead_attn::self_norm_add::rocblas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward.");
} }
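// NOTE: the two bindings above are the entire Python-facing surface of this
// extension. "forward" runs layer norm, the fused QKV input projection, the
// (optionally masked) softmax plus attention dropout, the two attention GEMMs,
// the output projection, and a dropout-add residual, and hands back the
// intermediates that fwd_cuda below returns (lyr_nrm_results, lyr_nrm_mean,
// lyr_nrm_invvar, input_lin_results, softmax_results, dropout_results,
// dropout_mask, matmul2_results, dropout_add_mask, outputs). "backward" takes
// those saved tensors plus the weights and returns input_grads,
// lyr_nrm_gamma_grads, lyr_nrm_beta_grads, input_weight_grads and
// output_weight_grads; the Python wrapper in apex/contrib/multihead_attn is
// presumably what stashes the forward outputs and feeds them back here.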
#include <vector>
#include <math.h>
#include <iostream> #include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
//#include <cuda_profiler_api.h> //#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h" #include "dropout.h"
#include "layer_norm.h" #include "layer_norm.h"
#include "softmax.h"
// symbol to be automatically resolved by PyTorch libs #include "strided_batched_gemm.h"
extern THCState *state;
namespace multihead_attn { namespace multihead_attn {
namespace self_norm_add { namespace self_norm_add {
namespace rocblas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
bool use_time_mask, int heads, torch::Tensor const &inputs,
bool is_training, torch::Tensor const &lyr_nrm_gamma_weights,
int heads, torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const& inputs, torch::Tensor const &input_weights,
torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const &output_weights,
torch::Tensor const& lyr_nrm_beta_weights, const uint8_t *pad_mask,
torch::Tensor const& input_weights, float dropout_prob) {
torch::Tensor const& output_weights, const int embed_dim = inputs.size(2);
const uint8_t* pad_mask, const int sequences = inputs.size(1);
float dropout_prob const int q_seq_len = inputs.size(0);
) const int k_seq_len = q_seq_len;
{ const int batches = sequences * q_seq_len;
const int embed_dim = inputs.size(2); const int total_tokens = batches * embed_dim;
const int sequences = inputs.size(1); const int head_dim = embed_dim / heads;
const int q_seq_len = inputs.size(0); const int output_lin_dim = 3 * embed_dim;
const int k_seq_len = q_seq_len; const int attn_batches = heads * sequences;
const int batches = sequences * q_seq_len; const int lead_dim = attn_batches * 3 * head_dim;
const int total_tokens = batches * embed_dim; const int batch_stride = 3 * head_dim;
const int head_dim = embed_dim / heads; const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const int output_lin_dim = 3 * embed_dim; const float alpha = 1.0;
const int attn_batches = heads * sequences; const float beta = 0.0;
const int lead_dim = attn_batches * 3 * head_dim; const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len; // There is no reason to use more than one stream as every kernel is
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
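// Worked example of the bookkeeping above (hypothetical shapes, chosen only
// for illustration): embed_dim = 1024, heads = 16, sequences = 32,
// q_seq_len = k_seq_len = 64 gives head_dim = 64, output_lin_dim = 3072,
// batches = 2048, total_tokens = 2,097,152, attn_batches = 512,
// lead_dim = 98,304, batch_stride = 192, dropout_elems = 2,097,152 and
// scale = 1/sqrt(64) = 0.125.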
// There is no reason to use more than one stream as every kernel is
// sequentially dependent // sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream); cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code) // 3 Intermediate Results + Output (Note: dropout intermediates are generated
auto act_options = inputs.options().requires_grad(false); // by ATen library code)
auto lyr_nrm_options = act_options.dtype(torch::kFloat32); auto act_options = inputs.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8); auto lyr_nrm_options = act_options.dtype(torch::kFloat32);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor lyr_nrm_mean = torch::empty({batches}, lyr_nrm_options);
torch::Tensor lyr_nrm_invvar = torch::empty({batches}, lyr_nrm_options); torch::Tensor lyr_nrm_mean = torch::empty({batches}, lyr_nrm_options);
torch::Tensor lyr_nrm_results = torch::empty_like(inputs, act_options); torch::Tensor lyr_nrm_invvar = torch::empty({batches}, lyr_nrm_options);
torch::Tensor lyr_nrm_results = torch::empty_like(inputs, act_options);
torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); torch::Tensor input_lin_results =
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); torch::Tensor softmax_results =
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor output_lin_results= torch::empty_like(inputs, act_options); torch::Tensor dropout_results =
torch::Tensor dropout_add_mask = torch::empty_like(inputs, mask_options); torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options); torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor output_lin_results = torch::empty_like(inputs, act_options);
torch::Tensor dropout_add_mask = torch::empty_like(inputs, mask_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options);
// Input Linear Results Pointers to Q, K, and V of interleaved activations // Input Linear Results Pointers to Q, K, and V of interleaved activations
void* q_lin_results_ptr = static_cast<void*>(input_lin_results.data_ptr()); void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
void* k_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim); void *k_lin_results_ptr = static_cast<void *>(
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim); static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
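// Layout note: input_lin_results was allocated above as
// {q_seq_len, sequences, 3 * embed_dim}, with each head's Q, K and V packed as
// consecutive head_dim-sized blocks ([q | k | v] per head), which is why the
// K and V pointers are simply the Q pointer offset by head_dim and
// 2 * head_dim halves. The strided batched GEMMs presumably walk these blocks
// using the lead_dim (attn_batches * 3 * head_dim) and batch_stride
// (3 * head_dim) computed above.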
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr()); void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'}; char a_layout_t{'t'};
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm // Layer Norm
HostApplyLayerNorm<at::Half,float>( HostApplyLayerNorm<at::Half, float>(
static_cast<at::Half*>(lyr_nrm_results.data_ptr()), static_cast<at::Half *>(lyr_nrm_results.data_ptr()),
static_cast<float*>(lyr_nrm_mean.data_ptr()), static_cast<float *>(lyr_nrm_mean.data_ptr()),
static_cast<float*>(lyr_nrm_invvar.data_ptr()), static_cast<float *>(lyr_nrm_invvar.data_ptr()),
static_cast<const at::Half*>(inputs.data_ptr()), static_cast<const at::Half *>(inputs.data_ptr()),
static_cast<int>(batches), // n1 static_cast<int>(batches), // n1
static_cast<int>(embed_dim), // n2 static_cast<int>(embed_dim), // n2
1.0e-5, 1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const at::Half*>(lyr_nrm_gamma_weights.data_ptr()), static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
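// For each of the `batches` rows of length embed_dim this computes standard
// layer norm: mean and invvar = 1/sqrt(var + 1e-5) are saved for the backward
// pass, and lyr_nrm_results = (inputs - mean) * invvar * gamma + beta is what
// the input projection below presumably consumes instead of the raw inputs
// (the raw inputs reappear only in the residual add at the end).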
// Input Linear Fwd // Input Linear Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -155,41 +154,31 @@ std::vector<torch::Tensor> fwd_cuda(
bool softmax_success = false; bool softmax_success = false;
if (pad_mask == nullptr) { if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>( softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr), reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
k_seq_len, k_seq_len, attn_batches * q_seq_len);
k_seq_len,
attn_batches*q_seq_len);
} else { } else {
if (use_time_mask) { if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>( softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr), reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
q_seq_len);
} else { } else {
softmax_success = dispatch_masked_softmax<half, half, float>( softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr), reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len,
k_seq_len, attn_batches * q_seq_len / sequences);
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
} }
} }
assert(softmax_success); assert(softmax_success);
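// Three softmax paths: no mask, the time mask, or the padding mask. For the
// padding mask, the last argument works out to
// attn_batches * q_seq_len / sequences = heads * q_seq_len, i.e. all rows
// belonging to one sequence share a mask row, presumably so a single
// per-sequence pad mask of length k_seq_len is broadcast across every head
// and query position.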
if (is_training) { if (is_training) {
apex_fused_dropout_cuda<at::Half,float,uint32_t>( apex_fused_dropout_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const*>(softmax_results.data_ptr()), static_cast<at::Half const *>(softmax_results.data_ptr()),
static_cast<at::Half*>(dropout_results.data_ptr()), static_cast<at::Half *>(dropout_results.data_ptr()),
static_cast<uint8_t*>(dropout_mask.data_ptr()), static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
dropout_elems, (1.0f - dropout_prob));
(1.0f - dropout_prob));
} }
// Matmul2 // Matmul2
@@ -245,99 +234,84 @@ std::vector<torch::Tensor> fwd_cuda(
// End-of-block Dropout-Add // End-of-block Dropout-Add
if (is_training) { if (is_training) {
apex_dropout_add_cuda<at::Half,float,uint32_t>( apex_dropout_add_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const*>(output_lin_results.data_ptr()), static_cast<at::Half const *>(output_lin_results.data_ptr()),
static_cast<at::Half const*>(inputs.data_ptr()), static_cast<at::Half const *>(inputs.data_ptr()),
static_cast<at::Half*>(outputs.data_ptr()), static_cast<at::Half *>(outputs.data_ptr()),
static_cast<uint8_t*>(dropout_add_mask.data_ptr()), static_cast<uint8_t *>(dropout_add_mask.data_ptr()), total_tokens,
total_tokens, (1.0f - dropout_prob));
(1.0f - dropout_prob));
} else { } else {
apex_add_cuda<at::Half,float,uint32_t>( apex_add_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const*>(output_lin_results.data_ptr()), static_cast<at::Half const *>(output_lin_results.data_ptr()),
static_cast<at::Half const*>(inputs.data_ptr()), static_cast<at::Half const *>(inputs.data_ptr()),
static_cast<at::Half*>(outputs.data_ptr()), static_cast<at::Half *>(outputs.data_ptr()), total_tokens);
total_tokens);
} }
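// Residual connection: in training, dropout is applied to the attention
// block's output before the original (un-normalized) inputs are added, and
// dropout_add_mask records the kept elements for bwd_cuda; in eval mode the
// two tensors are simply added. The (1.0f - dropout_prob) argument is the keep
// probability, so this is presumably inverted dropout (kept values scaled by
// 1/(1 - p)), matching the 1/(1 - p) rescale in the backward pass.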
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, input_lin_results,
lyr_nrm_results, softmax_results, dropout_results, dropout_mask, matmul2_results,
lyr_nrm_mean, dropout_add_mask, outputs};
lyr_nrm_invvar,
input_lin_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
dropout_add_mask,
outputs
};
} }
std::vector<torch::Tensor> bwd_cuda( std::vector<torch::Tensor> bwd_cuda(
int heads, int heads, torch::Tensor const &output_grads,
torch::Tensor const& output_grads, torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const& matmul2_results, torch::Tensor const &softmax_results,
torch::Tensor const& dropout_results, torch::Tensor const &input_lin_results,
torch::Tensor const& softmax_results, torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const& input_lin_results, torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
torch::Tensor const& lyr_nrm_results, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_mean, torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const& lyr_nrm_invvar, torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const& inputs, torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
torch::Tensor const& lyr_nrm_gamma_weights, float dropout_prob) {
torch::Tensor const& lyr_nrm_beta_weights, const int embed_dim = inputs.size(2);
torch::Tensor const& input_weights, const int sequences = inputs.size(1);
torch::Tensor const& output_weights, const int q_seq_len = inputs.size(0);
torch::Tensor const& dropout_mask, const int k_seq_len = q_seq_len;
torch::Tensor const& dropout_add_mask, const int batches = sequences * q_seq_len;
float dropout_prob const int total_tokens = batches * embed_dim;
) const int head_dim = embed_dim / heads;
{ const int output_lin_dim = 3 * embed_dim;
const int embed_dim = inputs.size(2); const int attn_batches = heads * sequences;
const int sequences = inputs.size(1); const int lead_dim = attn_batches * 3 * head_dim;
const int q_seq_len = inputs.size(0); const int batch_stride = 3 * head_dim;
const int k_seq_len = q_seq_len; const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const int batches = sequences * q_seq_len; const float alpha = 1.0;
const int total_tokens = batches * embed_dim; const float beta = 0.0;
const int head_dim = embed_dim / heads; const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const int output_lin_dim = 3 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// TODO: Streams can be used in Backprop but I haven't added more than one // TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code // in my first attempt to create the code
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream); cublasSetStream(handle, stream);
// Output Tensor Allocations // Output Tensor Allocations
torch::Tensor input_grads = torch::empty_like(inputs); torch::Tensor input_grads = torch::empty_like(inputs);
torch::Tensor lyr_nrm_gamma_grads = torch::empty_like(lyr_nrm_gamma_weights); torch::Tensor lyr_nrm_gamma_grads = torch::empty_like(lyr_nrm_gamma_weights);
torch::Tensor lyr_nrm_beta_grads = torch::empty_like(lyr_nrm_beta_weights); torch::Tensor lyr_nrm_beta_grads = torch::empty_like(lyr_nrm_beta_weights);
torch::Tensor input_weight_grads = torch::empty_like(input_weights); torch::Tensor input_weight_grads = torch::empty_like(input_weights);
torch::Tensor output_weight_grads = torch::empty_like(output_weights); torch::Tensor output_weight_grads = torch::empty_like(output_weights);
// Intermediate Tensor Allocations // Intermediate Tensor Allocations
torch::Tensor dropout_add_grads = torch::empty_like(output_grads); torch::Tensor dropout_add_grads = torch::empty_like(output_grads);
torch::Tensor output_lin_grads = torch::empty_like(matmul2_results); torch::Tensor output_lin_grads = torch::empty_like(matmul2_results);
torch::Tensor matmul2_grads = torch::empty_like(dropout_results); torch::Tensor matmul2_grads = torch::empty_like(dropout_results);
torch::Tensor input_lin_output_grads = torch::empty_like(input_lin_results); torch::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
torch::Tensor input_lin_grads = torch::empty_like(inputs); torch::Tensor input_lin_grads = torch::empty_like(inputs);
auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()); auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim; auto k_lin_results_ptr =
auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim; static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr =
auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()); static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;
auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim; auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;
char a_layout_n{'n'}; char a_layout_n{'n'};
char a_layout_t{'t'}; char a_layout_t{'t'};
@@ -346,14 +320,13 @@ std::vector<torch::Tensor> bwd_cuda(
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Dropout Add Backward // Dropout Add Backward
apex_masked_scale_cuda<at::Half,float,uint32_t>( apex_masked_scale_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const*>(output_grads.data_ptr()), static_cast<at::Half const *>(output_grads.data_ptr()),
static_cast<at::Half*>(dropout_add_grads.data_ptr()), static_cast<at::Half *>(dropout_add_grads.data_ptr()),
static_cast<uint8_t const*>(dropout_add_mask.data_ptr()), static_cast<uint8_t const *>(dropout_add_mask.data_ptr()), total_tokens,
total_tokens, (1.0 / (1.0 - dropout_prob)));
(1.0 / (1.0 - dropout_prob)));
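// Backward of the dropout-add residual: elements dropped in the forward pass
// get zero gradient and kept elements are rescaled by 1/(1 - p), the usual
// inverted-dropout backward. The same output_grads also flow unchanged into
// the residual branch, which is handled by the fused layer-norm backward at
// the end of this function.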
// Output Linear Dgrad // Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
@@ -463,12 +436,10 @@ std::vector<torch::Tensor> bwd_cuda(
// Softmax Grad // Softmax Grad
bool softmax_success = false; bool softmax_success = false;
softmax_success = dispatch_softmax_backward<half, half, float>( softmax_success = dispatch_softmax_backward<half, half, float>(
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()), reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
k_seq_len, k_seq_len, attn_batches * q_seq_len);
k_seq_len,
attn_batches*q_seq_len);
assert(softmax_success); assert(softmax_success);
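// In-place softmax backward using the saved forward probabilities y
// (softmax_results) and the incoming grads g (matmul2_grads): per row of
// length k_seq_len, dx_i = y_i * (g_i - sum_j g_j * y_j).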
// Matmul1 Dgrad1 // Matmul1 Dgrad1
@@ -572,31 +543,23 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
// Fused Layer Norm Bwd with Residual Add // Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half,float>( HostLayerNormGradient<half, float>(
static_cast<const half*>(input_lin_grads.data_ptr()), static_cast<const half *>(input_lin_grads.data_ptr()),
static_cast<half const*>(output_grads.data_ptr()), static_cast<half const *>(output_grads.data_ptr()),
static_cast<const float*>(lyr_nrm_mean.data_ptr()), static_cast<const float *>(lyr_nrm_mean.data_ptr()),
static_cast<const float*>(lyr_nrm_invvar.data_ptr()), static_cast<const float *>(lyr_nrm_invvar.data_ptr()), inputs,
inputs, static_cast<int>(batches), // n1
static_cast<int>(batches), // n1 static_cast<int>(embed_dim), // n2
static_cast<int>(embed_dim), // n2 static_cast<const half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const half*>(lyr_nrm_gamma_weights.data_ptr()), static_cast<const half *>(lyr_nrm_beta_weights.data_ptr()), 1.0e-5,
static_cast<const half*>(lyr_nrm_beta_weights.data_ptr()), static_cast<half *>(input_grads.data_ptr()),
1.0e-5, static_cast<half *>(lyr_nrm_gamma_grads.data_ptr()),
static_cast<half*>(input_grads.data_ptr()), static_cast<half *>(lyr_nrm_beta_grads.data_ptr()));
static_cast<half*>(lyr_nrm_gamma_grads.data_ptr()),
static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {input_grads, lyr_nrm_gamma_grads, lyr_nrm_beta_grads,
input_grads, input_weight_grads, output_weight_grads};
lyr_nrm_gamma_grads,
lyr_nrm_beta_grads,
input_weight_grads,
output_weight_grads
};
} }
} // end namespace rocblas_gemmex } // end namespace rocblas_gemmex