Unverified commit 29b36315 authored by Hubert Lu, committed by GitHub

Cherry-picked the commit from upstream for a faster --fast_multihead_attn build (#76)



* Faster `--fast_multihead_attn` build (#1245)

* merge .so files

* odr

* fix build

* update import

* apply psf/black with max line length of 120

* update

* fix

* update

* build fixed again but undefined symbol again

* fix 2, still layer norm grad is undefined

* remove unused cpp files

* without layer_norm.cuh, import works

* import fast_multihead_attn works...

but why? Was the unnecessary `#include "layer_norm.cuh"` the culprit that
kept the shared objects from linking `HostApplyLayerNorm` and
`HostLayerNormGradient`?

* clean up layer norm

* Fix some bugs
Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>
parent 5ecad142
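The "merge .so files" and "odr" items above refer to collapsing the many per-variant extensions into a single fast_multihead_attn module; the recurring fix in the diff below is to give header-defined helpers internal linkage by wrapping them in anonymous namespaces. A minimal sketch of that pattern, with hypothetical names (not code from this commit):

// shared_helpers.cuh -- hypothetical header included by several .cu files
// that are now linked into ONE extension .so
#pragma once

namespace {                         // anonymous namespace => internal linkage:
constexpr int kUnroll = 4;          // each translation unit gets its own copy,
__device__ __forceinline__ float    // so the merged module links without
half_scale(float v) {               // multiply-defined (or missing) symbols
  return 0.5f * v;
}
} // namespace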
@@ -27,7 +27,7 @@
 #pragma once
-#include <multihead_attn/philox.h>
+#include <multihead_attn/philox.cuh>
 #include <fmha.h>
 #include <fmha/utils.h>
...
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
torch::Tensor const &input,
const half *pad_mask, float dropout_prob);
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
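The CHECK_* macros above are defined but not exercised in this translation unit; for reference, here is a hedged sketch of how they are typically used to guard a wrapper. The function name is illustrative and not part of this file:

// Hypothetical guard: reject CPU or non-contiguous tensors before any CUDA work.
static void validate_softmax_input(torch::Tensor const &input) {
  CHECK_INPUT(input);  // expands to CHECK_CUDA(input); CHECK_CONTIGUOUS(input)
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
}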
std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
torch::Tensor const &input,
torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
}
return fwd_cuda(is_training, heads, input,
use_mask ? static_cast<const half *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
// "Only BYTE is supported");
return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
dropout_prob);
}
} // namespace additive_mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd,
"Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd,
"Self Multihead Attention masked softmax dropout -- Backward.");
}
@@ -11,8 +11,8 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
-#include "dropout.h"
-#include "softmax.h"
+#include "dropout.cuh"
+#include "softmax.cuh"
 // symbol to be automatically resolved by PyTorch libs
@@ -27,7 +27,7 @@ std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
   const int sequences = attn_batches / heads;
   const int q_seq_len = input.size(1);
   const int k_seq_len = q_seq_len;
-  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
+  // const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
   // There is no reason to use more than one stream as every kernel is
   // sequentially dependent
@@ -86,7 +86,7 @@ torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
   const int attn_batches = output_grads.size(0);
   const int q_seq_len = output_grads.size(1);
   const int k_seq_len = q_seq_len;
-  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
+  // const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
   // TODO: Streams can be used in Backprop but I haven't added more than one
   // in my first attempt to create the code
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
@@ -110,4 +110,4 @@ torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
 }
 } // namespace additive_mask_softmax_dropout
 } // namespace fused_softmax
 } // namespace multihead_attn
\ No newline at end of file
+#pragma once
 #include <ATen/ATen.h>
 #if !defined(NEW_GENERATOR_PATH)
@@ -9,7 +10,9 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <curand_kernel.h>
-const int UNROLL = 4;
+namespace {
+constexpr int UNROLL = 4;
+} // namespace
 template <typename scalar_t, typename accscalar_t, typename IndexType>
 __global__ void
...
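The dropout kernels that follow the truncated template above consume UNROLL as a per-thread unrolling factor. A hedged, self-contained sketch of that general pattern follows; the kernel name and the element-wise operation are illustrative, not taken from this file:

namespace {
constexpr int UNROLL = 4;
} // namespace

// Each thread handles UNROLL consecutive elements per launch.
template <typename scalar_t>
__global__ void scale_kernel(scalar_t *out, const scalar_t *in, float s, int n) {
  int base = (blockIdx.x * blockDim.x + threadIdx.x) * UNROLL;
  #pragma unroll
  for (int i = 0; i < UNROLL; ++i) {
    int idx = base + i;
    if (idx < n) {
      out[idx] = static_cast<scalar_t>(static_cast<float>(in[idx]) * s);
    }
  }
}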
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace encdec {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv,
torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
input_weights_q, input_weights_kv, output_weights,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_q_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_kv_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_q_results, input_lin_kv_results,
inputs_q, inputs_kv, input_weights_q, input_weights_kv,
output_weights, dropout_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
} // end namespace encdec
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::encdec::rocblas_gemmex::fwd, "Encdec Multihead Attention Forward.");
m.def("backward", &multihead_attn::encdec::rocblas_gemmex::bwd, "Encdec Multihead Attention Backward.");
}
@@ -11,10 +11,9 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
-#include "dropout.h"
-#include "layer_norm.h"
-#include "softmax.h"
-#include "strided_batched_gemm.h"
+#include "dropout.cuh"
+#include "softmax.cuh"
+#include "strided_batched_gemm.cuh"
 namespace multihead_attn {
 namespace encdec {
@@ -86,6 +85,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Q Fwd // Input Linear Q Fwd
@@ -110,8 +111,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
output_lin_q_dim, output_lin_q_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
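For readers unfamiliar with the arguments annotated above: the diff replaces the old algo and solution_index variables with the standard algorithm and solution index 0 in every rocblas_gemm_ex call, with fp16 A/B/C/D and fp32 accumulation. A hedged sketch of that call shape; the wrapper name and the include path are assumptions, not code from this commit:

#include <rocblas/rocblas.h>   // on older ROCm the header may be <rocblas.h>

// Illustrative helper mirroring the calls in this file (C is updated in place,
// so it is passed again as D).
static rocblas_status gemm_fp16_fp32accum(
    rocblas_handle handle, rocblas_operation op_a, rocblas_operation op_b,
    int m, int n, int k, const float *alpha,
    const void *a, int lda, const void *b, int ldb,
    const float *beta, void *c, int ldc) {
  return rocblas_gemm_ex(handle, op_a, op_b, m, n, k,
                         alpha,
                         a, rocblas_datatype_f16_r, lda,
                         b, rocblas_datatype_f16_r, ldb,
                         beta,
                         c, rocblas_datatype_f16_r, ldc,   // C
                         c, rocblas_datatype_f16_r, ldc,   // D (in-place)
                         rocblas_datatype_f32_r,           // compute_type
                         rocblas_gemm_algo_standard,
                         0 /*solution_index*/,
                         0 /*flags*/);
}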
// Input Linear KV Fwd // Input Linear KV Fwd
@@ -136,8 +137,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
output_lin_kv_dim, output_lin_kv_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
@@ -161,7 +162,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
attn_batches, attn_batches,
flags); flags);
// Padded Softmax // Padded Softmax
bool softmax_success = false; bool softmax_success = false;
@@ -215,7 +216,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
head_dim*attn_batches, head_dim*attn_batches,
head_dim, head_dim,
attn_batches, attn_batches,
flags); flags);
// Output Linear // Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -239,8 +240,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
@@ -317,6 +318,8 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
@@ -350,8 +353,8 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// Output Linear Wgrad // Output Linear Wgrad
@@ -376,8 +379,8 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// MatMul2 Dgrad1 // MatMul2 Dgrad1
@@ -401,7 +404,7 @@ std::vector<torch::Tensor> bwd_cuda(
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
attn_batches, attn_batches,
flags); flags);
// Matmul2 Dgrad2 // Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n, gemm_switch_fp32accum( a_layout_n,
@@ -424,7 +427,7 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim_kv, lead_dim_kv,
batch_stride_kv, batch_stride_kv,
attn_batches, attn_batches,
flags); flags);
// Apply Dropout Mask and Scale by Dropout Probability // Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>( apex_masked_scale_cuda<at::Half,float,uint32_t>(
@@ -464,7 +467,7 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim_q, lead_dim_q,
batch_stride_q, batch_stride_q,
attn_batches, attn_batches,
flags); flags);
// Matmul1 Dgrad2 // Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n, gemm_switch_fp32accum( a_layout_n,
@@ -487,7 +490,7 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim_kv, lead_dim_kv,
batch_stride_kv, batch_stride_kv,
attn_batches, attn_batches,
flags); flags);
// Input Linear Q Dgrad // Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -511,8 +514,8 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// Input Linear Q Wgrad // Input Linear Q Wgrad
@@ -537,8 +540,8 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// Input Linear KV Dgrad // Input Linear KV Dgrad
@@ -563,8 +566,8 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// Input Linear KV Wgrad // Input Linear KV Wgrad
@@ -589,8 +592,8 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); // TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
......
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace encdec_norm_add {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
torch::Tensor const &dropout_add_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q,
input_weights_kv, output_weights,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
torch::Tensor const &dropout_add_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_q_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_kv_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_mean.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_invvar.dim() == 1, "expected 1D tensor");
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float,
"Only FLOAT is supported");
AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float,
"Only FLOAT is supported");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_q_results, input_lin_kv_results,
lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs_q,
inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights,
input_weights_q, input_weights_kv, output_weights,
dropout_mask, dropout_add_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
} // end namespace encdec_norm_add
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::encdec_norm_add::rocblas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def("backward", &multihead_attn::encdec_norm_add::rocblas_gemmex::bwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.");
}
@@ -11,10 +11,10 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
-#include "dropout.h"
-#include "layer_norm.h"
-#include "softmax.h"
-#include "strided_batched_gemm.h"
+#include "dropout.cuh"
+#include "layer_norm.cuh"
+#include "softmax.cuh"
+#include "strided_batched_gemm.cuh"
 namespace multihead_attn {
 namespace encdec_norm_add {
@@ -101,6 +101,8 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm // Layer Norm
HostApplyLayerNorm<at::Half, float>( HostApplyLayerNorm<at::Half, float>(
@@ -122,23 +124,23 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()), static_cast<const void*>(input_weights_q.data_ptr()),
a_type, rocblas_datatype_f16_r /*a_type*/,
embed_dim, embed_dim,
//static_cast<const void*>(inputs_q.data_ptr()), //static_cast<const void*>(inputs_q.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()), static_cast<const void*>(lyr_nrm_results.data_ptr()),
b_type, rocblas_datatype_f16_r /*b_type*/,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
q_lin_results_ptr, q_lin_results_ptr,
c_type, rocblas_datatype_f16_r /*c_type*/,
output_lin_q_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_q_dim, output_lin_q_dim,
q_lin_results_ptr, rocblas_datatype_f32_r /*compute_type*/,
d_type, rocblas_gemm_algo_standard /*algo*/,
output_lin_q_dim, 0 /*solution_index*/,
compute_type, flags));
algo,
solution_index,
flags));
// Input Linear KV Fwd // Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -149,22 +151,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()), static_cast<const void*>(input_weights_kv.data_ptr()),
a_type, rocblas_datatype_f16_r /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()), static_cast<const void*>(inputs_kv.data_ptr()),
b_type, rocblas_datatype_f16_r /*b_type*/,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
k_lin_results_ptr, k_lin_results_ptr,
c_type, rocblas_datatype_f16_r /*c_type*/,
output_lin_kv_dim, output_lin_kv_dim,
k_lin_results_ptr, k_lin_results_ptr,
d_type, rocblas_datatype_f16_r /*d_type*/,
output_lin_kv_dim, output_lin_kv_dim,
compute_type, rocblas_datatype_f32_r /*compute_type*/,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t, gemm_switch_fp32accum( a_layout_t,
b_layout_n, b_layout_n,
@@ -182,11 +184,11 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(softmax_results_ptr), static_cast<half*>(softmax_results_ptr),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr), static_cast<half*>(softmax_results_ptr),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
attn_batches, attn_batches,
flags); flags);
// Padded Softmax // Padded Softmax
bool softmax_success = false; bool softmax_success = false;
@@ -237,11 +239,11 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(matmul2_results.data_ptr()), static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches, head_dim*attn_batches,
head_dim, head_dim,
static_cast<half*>(matmul2_results.data_ptr()), static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches, head_dim*attn_batches,
head_dim, head_dim,
attn_batches, attn_batches,
flags); flags);
// Output Linear // Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -252,22 +254,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
a_type, rocblas_datatype_f16_r /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
b_type, rocblas_datatype_f16_r /*b_type*/,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()), static_cast<void*>(output_lin_results.data_ptr()),
c_type, rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
static_cast<void*>(output_lin_results.data_ptr()), rocblas_datatype_f32_r /*compute_type*/,
d_type, rocblas_gemm_algo_standard /*algo*/,
embed_dim, 0 /*solution_index*/,
compute_type, flags));
algo,
solution_index,
flags));
// End-of-block Dropout-Add // End-of-block Dropout-Add
if (is_training) { if (is_training) {
@@ -371,6 +373,8 @@ std::vector<torch::Tensor> bwd_cuda(
char a_layout_t{'t'}; char a_layout_t{'t'};
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
@@ -400,22 +404,22 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
a_type, rocblas_datatype_f16_r /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()), static_cast<const void*>(dropout_add_grads.data_ptr()),
b_type, rocblas_datatype_f16_r /*b_type*/,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()), static_cast<void*>(output_lin_grads.data_ptr()),
c_type, rocblas_datatype_f16_r /*c_type*/,
embed_dim, embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()), static_cast<void*>(output_lin_grads.data_ptr()),
d_type, rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
compute_type, rocblas_datatype_f32_r /*compute_type*/,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// Output Linear Wgrad // Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -426,22 +430,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q, batches_q,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
a_type, rocblas_datatype_f16_r /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()), static_cast<const void*>(dropout_add_grads.data_ptr()),
b_type, rocblas_datatype_f16_r /*b_type*/,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()), static_cast<void*>(output_weight_grads.data_ptr()),
c_type, rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f32_r /*compute_type*/,
d_type, rocblas_gemm_algo_standard /*algo*/,
embed_dim, 0 /*solution_index*/,
compute_type, flags));
algo,
solution_index,
flags));
// MatMul2 Dgrad1 // MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t, gemm_switch_fp32accum( a_layout_t,
@@ -460,11 +464,11 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
attn_batches, attn_batches,
flags); flags);
// Matmul2 Dgrad2 // Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n, gemm_switch_fp32accum( a_layout_n,
@@ -483,11 +487,11 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr, v_lin_grads_ptr,
lead_dim_kv, lead_dim_kv,
batch_stride_kv, batch_stride_kv,
v_lin_grads_ptr, v_lin_grads_ptr,
lead_dim_kv, lead_dim_kv,
batch_stride_kv, batch_stride_kv,
attn_batches, attn_batches,
flags); flags);
// Apply Dropout Mask and Scale by Dropout Probability // Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>( apex_masked_scale_cuda<at::Half,float,uint32_t>(
@@ -523,11 +527,11 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr, q_lin_grads_ptr,
lead_dim_q, lead_dim_q,
batch_stride_q, batch_stride_q,
q_lin_grads_ptr, q_lin_grads_ptr,
lead_dim_q, lead_dim_q,
batch_stride_q, batch_stride_q,
attn_batches, attn_batches,
flags); flags);
// Matmul1 Dgrad2 // Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n, gemm_switch_fp32accum( a_layout_n,
@@ -546,11 +550,11 @@ std::vector<torch::Tensor> bwd_cuda(
k_lin_grads_ptr, k_lin_grads_ptr,
lead_dim_kv, lead_dim_kv,
batch_stride_kv, batch_stride_kv,
k_lin_grads_ptr, k_lin_grads_ptr,
lead_dim_kv, lead_dim_kv,
batch_stride_kv, batch_stride_kv,
attn_batches, attn_batches,
flags); flags);
// Input Linear Q Dgrad // Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -561,23 +565,23 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_q_dim, output_lin_q_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()), static_cast<const void*>(input_weights_q.data_ptr()),
a_type, rocblas_datatype_f16_r /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
b_type, rocblas_datatype_f16_r /*b_type*/,
output_lin_q_dim, output_lin_q_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
//static_cast<void*>(input_q_grads.data_ptr()), //static_cast<void*>(input_q_grads.data_ptr()),
static_cast<void*>(input_lin_q_grads.data_ptr()), static_cast<void*>(input_lin_q_grads.data_ptr()),
c_type, rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_lin_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
static_cast<void*>(input_lin_q_grads.data_ptr()), rocblas_datatype_f32_r /*compute_type*/,
d_type, rocblas_gemm_algo_standard /*algo*/,
embed_dim, 0 /*solution_index*/,
compute_type, flags));
algo,
solution_index,
flags));
// Input Linear Q Wgrad // Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -588,22 +592,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q, batches_q,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_q.data_ptr()), static_cast<const void*>(inputs_q.data_ptr()),
a_type, rocblas_datatype_f16_r /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
b_type, rocblas_datatype_f16_r /*b_type*/,
output_lin_q_dim, output_lin_q_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_q_grads.data_ptr()), static_cast<void*>(input_weight_q_grads.data_ptr()),
c_type, rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()), rocblas_datatype_f32_r /*compute_type*/,
d_type, rocblas_gemm_algo_standard /*algo*/,
embed_dim, 0 /*solution_index*/,
compute_type, flags));
algo,
solution_index,
flags));
// Input Linear KV Dgrad // Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -614,22 +618,22 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_kv_dim, output_lin_kv_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()), static_cast<const void*>(input_weights_kv.data_ptr()),
a_type, rocblas_datatype_f16_r /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(k_lin_grads_ptr), static_cast<const void*>(k_lin_grads_ptr),
b_type, rocblas_datatype_f16_r /*b_type*/,
output_lin_kv_dim, output_lin_kv_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_kv_grads.data_ptr()), static_cast<void*>(input_kv_grads.data_ptr()),
c_type, rocblas_datatype_f16_r /*c_type*/,
embed_dim, embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()), static_cast<void*>(input_kv_grads.data_ptr()),
d_type, rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
compute_type, rocblas_datatype_f32_r /*compute_type*/,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// Input Linear KV Wgrad // Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -640,22 +644,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches_kv, batches_kv,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_kv.data_ptr()), static_cast<const void*>(inputs_kv.data_ptr()),
a_type, rocblas_datatype_f16_r /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(k_lin_grads_ptr), static_cast<const void*>(k_lin_grads_ptr),
b_type, rocblas_datatype_f16_r /*b_type*/,
output_lin_kv_dim, output_lin_kv_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()), static_cast<void*>(input_weight_kv_grads.data_ptr()),
c_type, rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()), rocblas_datatype_f32_r /*compute_type*/,
d_type, rocblas_gemm_algo_standard /*algo*/,
embed_dim, 0 /*solution_index*/,
compute_type, flags));
algo,
solution_index,
flags));
// Fused Layer Norm Bwd with Residual Add // Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half,float>( HostLayerNormGradient<half,float>(
......
#include "ATen/ATen.h" #pragma once
#include "ATen/cuda/DeviceUtils.cuh"
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/DeviceUtils.cuh>
namespace {
template <typename U> template <typename U>
__device__ void cuWelfordOnlineSum(const U curr, U &mu, U &sigma2, U &count) { __device__ void cuWelfordOnlineSum(const U curr, U &mu, U &sigma2, U &count) {
count = count + U(1); count = count + U(1);
@@ -211,19 +212,15 @@ template<typename U> U rsqrt(U v) {
 //}
 #if defined __HIP_PLATFORM_HCC__
-__device__ float rsqrt(float v) {
-  return rsqrtf(v);
-}
+__device__ float rsqrt(float v) { return rsqrtf(v); }
 #else
-template<> float rsqrt(float v) {
-  return rsqrtf(v);
-}
+template<> float rsqrt(float v) { return rsqrtf(v); }
 #endif
-template<> double rsqrt(double v) {
-  return rsqrt(v);
-}
-namespace {
+template<> double rsqrt(double v) { return rsqrt(v); }
+// template <typename U> __device__ U rsqrt(U v) { return U(1) / sqrt(v); }
+// template <> __device__ float rsqrt(float v) { return rsqrtf(v); }
+// template <> __device__ double rsqrt(double v) { return rsqrt(v); }
 // This is the un-specialized struct. Note that we prevent instantiation of
 // this struct by putting an undefined symbol in the function body so it won't
 // compile.
@@ -240,7 +237,6 @@ namespace {
 // };
 // https://github.com/NVIDIA/apex/issues/246
 template <typename T> struct SharedMemory;
 template <> struct SharedMemory<float> {
   __device__ float *getPointer() {
     extern __shared__ float s_float[];
@@ -254,7 +250,6 @@ template <> struct SharedMemory<double> {
     return s_double;
   }
 };
-} // namespace
 template <typename T, typename U>
 __global__ void
@@ -473,6 +468,7 @@ cuComputeGradGammaBeta(const U *part_grad_gamma, const U *part_grad_beta,
   }
 }
 template <typename T, typename U>
 __global__ void
 cuComputeGradInput(const T *__restrict__ dout, const T *__restrict__ dout_resid,
@@ -650,3 +646,4 @@ void HostLayerNormGradient(const T *dout, const T *dout_resid, const U *mean,
       dout, dout_resid, static_cast<T *>(input.data_ptr()), n1, n2, mean,
       invvar, U(epsilon), gamma, grad_input);
 }
+} // namespace
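The kernels in this header start from a Welford-style online update (cuWelfordOnlineSum above) to get each row's mean and inverse variance in a single pass. As a reference, here is a hedged host-side sketch of that update; the struct and function names are illustrative, not from this file:

struct Welford {
  float mu = 0.f;     // running mean
  float m2 = 0.f;     // running sum of squared deviations
  float count = 0.f;  // number of samples seen
};

inline void welford_push(Welford &w, float x) {
  w.count += 1.f;
  const float delta = x - w.mu;
  w.mu += delta / w.count;
  w.m2 += delta * (x - w.mu);
}
// variance ~= w.m2 / w.count; layer norm then applies rsqrt(variance + epsilon).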
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
torch::Tensor const &input,
const uint8_t *pad_mask,
float dropout_prob);
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask,
const uint8_t *padding_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
torch::Tensor const &input,
torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(is_training, heads, input,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask,
torch::Tensor const &padding_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
// "Only BYTE is supported");
return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
use_mask
? static_cast<const uint8_t *>(padding_mask.data_ptr())
: nullptr,
dropout_prob);
}
} // end namespace mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd,
"Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd,
"Self Multihead Attention masked softmax dropout -- Backward.");
}
@@ -11,8 +11,8 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
-#include "dropout.h"
-#include "softmax.h"
+#include "dropout.cuh"
+#include "softmax.cuh"
 namespace multihead_attn {
 namespace fused_softmax {
...
 #pragma once
 // Philox CUDA.
+namespace {
 class Philox {
 public:
   __device__ inline Philox(unsigned long long seed,
@@ -85,8 +87,10 @@ private:
   static const unsigned long kPhiloxSB = 0xCD9E8D57;
 };
 // Inverse of 2^32.
-#define M_RAN_INVM32 2.3283064e-10f
+constexpr float M_RAN_INVM32 = 2.3283064e-10f;
 __device__ __inline__ float4 uniform4(uint4 x) {
   return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,
                      x.w * M_RAN_INVM32);
 }
+} // namespace
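A hedged usage sketch of the helpers above: one Philox draw yields four 32-bit integers, and uniform4 rescales them into [0, 1) floats that a dropout kernel can compare against the keep threshold. This assumes Philox::operator() returns a uint4, which the truncated class body does not show; the function name is illustrative:

// Illustrative only -- not code from this header.
__device__ inline void dropout_keep4(Philox &rng, float dropout_prob,
                                     bool keep[4]) {
  const float4 r = uniform4(rng());   // four uniforms in [0, 1)
  keep[0] = r.x >= dropout_prob;
  keep[1] = r.y >= dropout_prob;
  keep[2] = r.z >= dropout_prob;
  keep[3] = r.w >= dropout_prob;
}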
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace self_bias_additive_mask {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs,
torch::Tensor const &input_weights,
torch::Tensor const &output_weights,
torch::Tensor const &input_biases,
torch::Tensor const &output_biases,
const half *pad_mask, float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
// torch::Tensor const& softmax_results,
torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
// torch::Tensor const& input_biases,
// torch::Tensor const& output_biases,
torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs, torch::Tensor const &input_weights,
torch::Tensor const &output_weights, torch::Tensor const &input_biases,
torch::Tensor const &output_biases, torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(use_mask, "a mask is required: use_mask=false is not supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
}
return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights,
output_weights, input_biases, output_biases,
use_mask ? static_cast<const half *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
bmm1_results, pad_mask, input_lin_results, inputs,
input_weights, output_weights, dropout_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
} // end namespace self_bias_additive_mask
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
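Unlike the byte-mask variants earlier in this diff, this binding insists on a HALF mask because the mask here is additive: it is added to the attention logits before the softmax (zero where a position is valid, a large negative value where it is padded). A hedged sketch of building such a mask with the PyTorch C++ API; the function name and the -65504 sentinel (the most negative finite fp16 value) are illustrative choices, not taken from this commit:

#include <torch/torch.h>

// byte_pad_mask: 0 = keep, 1 = padded. Returns an fp16 tensor to add to logits.
torch::Tensor make_additive_pad_mask(torch::Tensor byte_pad_mask) {
  auto mask = byte_pad_mask.to(torch::kHalf);
  return mask * -65504.0;   // padded positions receive a very negative logit
}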
@@ -11,10 +11,9 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
-#include "dropout.h"
-#include "layer_norm.h"
-#include "softmax.h"
-#include "strided_batched_gemm.h"
+#include "dropout.cuh"
+#include "softmax.cuh"
+#include "strided_batched_gemm.cuh"
 namespace multihead_attn {
 namespace self_bias_additive_mask {
@@ -87,6 +86,8 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd // Input Linear Fwd
@@ -112,8 +113,8 @@ std::vector<torch::Tensor> fwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
output_lin_dim, output_lin_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
...@@ -137,7 +138,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -137,7 +138,7 @@ std::vector<torch::Tensor> fwd_cuda(
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
attn_batches, attn_batches,
flags); flags);
// Padded Softmax // Padded Softmax
bool softmax_success = false; bool softmax_success = false;
...@@ -183,7 +184,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -183,7 +184,7 @@ std::vector<torch::Tensor> fwd_cuda(
head_dim*attn_batches, head_dim*attn_batches,
head_dim, head_dim,
attn_batches, attn_batches,
flags); flags);
outputs.copy_(output_biases); outputs.copy_(output_biases);
...@@ -209,8 +210,8 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -209,8 +210,8 @@ std::vector<torch::Tensor> fwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
...@@ -272,6 +273,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -272,6 +273,8 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
...@@ -305,8 +308,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -305,8 +308,8 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// Output Linear Wgrad // Output Linear Wgrad
...@@ -331,9 +334,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -331,9 +334,9 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false); auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1 // MatMul2 Dgrad1
...@@ -357,7 +360,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -357,7 +360,7 @@ std::vector<torch::Tensor> bwd_cuda(
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
attn_batches, attn_batches,
flags); flags);
// Matmul2 Dgrad2 // Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n, gemm_switch_fp32accum( a_layout_n,
...@@ -380,7 +383,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -380,7 +383,7 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim, lead_dim,
batch_stride, batch_stride,
attn_batches, attn_batches,
flags); flags);
// Apply Dropout Mask and Scale by Dropout Probability // Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad // Softmax Grad
...@@ -389,13 +392,13 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -389,13 +392,13 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half* const>(matmul2_grads.data_ptr()), static_cast<half* const>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(bmm1_results.data_ptr()), reinterpret_cast<half const*>(bmm1_results.data_ptr()),
reinterpret_cast<half const*>(pad_mask.data_ptr()), reinterpret_cast<half const*>(pad_mask.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()), static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob), 1.0/(1.0-dropout_prob),
k_seq_len, k_seq_len,
k_seq_len, k_seq_len,
attn_batches*q_seq_len/sequences, attn_batches*q_seq_len/sequences,
attn_batches*q_seq_len, attn_batches*q_seq_len,
stream); stream);
// Matmul1 Dgrad1 // Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n, gemm_switch_fp32accum( a_layout_n,
...@@ -418,7 +421,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -418,7 +421,7 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim, lead_dim,
batch_stride, batch_stride,
attn_batches, attn_batches,
flags); flags);
// Matmul1 Dgrad2 // Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n, gemm_switch_fp32accum( a_layout_n,
...@@ -441,7 +444,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -441,7 +444,7 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim, lead_dim,
batch_stride, batch_stride,
attn_batches, attn_batches,
flags); flags);
// Input Linear Dgrad // Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
...@@ -465,8 +468,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -465,8 +468,8 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// Input Linear Wgrad // Input Linear Wgrad
...@@ -491,8 +494,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -491,8 +494,8 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
...@@ -503,5 +506,5 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -503,5 +506,5 @@ std::vector<torch::Tensor> bwd_cuda(
} }
} // end namespace rocblas_gemmex } // end namespace rocblas_gemmex
} // end namespace self } // end namespace self_bias_additive_mask
} // end namespace multihead_attn } // end namespace multihead_attn
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace self_bias {
namespace rocblas_gemmex {
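// Declarations of the CUDA launchers; their definitions live in the matching
// CUDA source file. The pad mask is passed as a raw pointer so that a null
// value can signal "no masking".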
std::vector<torch::Tensor>
fwd_cuda(bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs, torch::Tensor const &input_weights,
torch::Tensor const &output_weights, torch::Tensor const &input_biases,
torch::Tensor const &output_biases, const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
// torch::Tensor const& input_biases,
// torch::Tensor const& output_biases,
torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
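// NOTE: CHECK_CUDA / CHECK_CONTIGUOUS / CHECK_INPUT are convenience macros for
// validating tensor arguments; the bindings below spell out AT_ASSERTM checks
// explicitly instead. A hypothetical use would look like:
//   void f(torch::Tensor const &x) { CHECK_INPUT(x); /* x is CUDA + contiguous */ }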
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs, torch::Tensor const &input_weights,
torch::Tensor const &output_weights, torch::Tensor const &input_biases,
torch::Tensor const &output_biases, torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights,
output_weights, input_biases, output_biases,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
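// Backward binding: checks ranks (3D activations, 2D weights) and dtypes
// (FP16 everywhere except the byte dropout mask) before dispatching to
// bwd_cuda; the bias gradients are reduced inside the CUDA implementation.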
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_results, inputs, input_weights,
output_weights, dropout_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
} // end namespace self_bias
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
...@@ -11,10 +11,9 @@ ...@@ -11,10 +11,9 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include "dropout.h" #include "dropout.cuh"
#include "layer_norm.h" #include "softmax.cuh"
#include "softmax.h" #include "strided_batched_gemm.cuh"
#include "strided_batched_gemm.h"
namespace multihead_attn { namespace multihead_attn {
namespace self_bias { namespace self_bias {
...@@ -79,6 +78,8 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads, ...@@ -79,6 +78,8 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd // Input Linear Fwd
...@@ -104,8 +105,8 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads, ...@@ -104,8 +105,8 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
output_lin_dim, output_lin_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
...@@ -129,7 +130,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads, ...@@ -129,7 +130,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
attn_batches, attn_batches,
flags); flags);
// Padded Softmax // Padded Softmax
bool softmax_success = false; bool softmax_success = false;
...@@ -183,7 +184,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads, ...@@ -183,7 +184,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
head_dim*attn_batches, head_dim*attn_batches,
head_dim, head_dim,
attn_batches, attn_batches,
flags); flags);
outputs.copy_(output_biases); outputs.copy_(output_biases);
...@@ -209,8 +210,8 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads, ...@@ -209,8 +210,8 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
...@@ -272,6 +273,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -272,6 +273,8 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
...@@ -305,8 +308,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -305,8 +308,8 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// Output Linear Wgrad // Output Linear Wgrad
...@@ -331,8 +334,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -331,8 +334,8 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false); auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
...@@ -357,7 +360,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -357,7 +360,7 @@ std::vector<torch::Tensor> bwd_cuda(
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
attn_batches, attn_batches,
flags); flags);
// Matmul2 Dgrad2 // Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n, gemm_switch_fp32accum( a_layout_n,
...@@ -380,7 +383,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -380,7 +383,7 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim, lead_dim,
batch_stride, batch_stride,
attn_batches, attn_batches,
flags); flags);
// Apply Dropout Mask and Scale by Dropout Probability // Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad // Softmax Grad
...@@ -413,7 +416,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -413,7 +416,7 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim, lead_dim,
batch_stride, batch_stride,
attn_batches, attn_batches,
flags); flags);
// Matmul1 Dgrad2 // Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n, gemm_switch_fp32accum( a_layout_n,
...@@ -436,7 +439,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -436,7 +439,7 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim, lead_dim,
batch_stride, batch_stride,
attn_batches, attn_batches,
flags); flags);
// Input Linear Dgrad // Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
...@@ -459,8 +462,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -459,8 +462,8 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// Input Linear Wgrad // Input Linear Wgrad
...@@ -485,8 +488,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -485,8 +488,8 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
......
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace self {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs,
torch::Tensor const &input_weights,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
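// Forward binding for the bias-free self-attention variant: validates 3D FP16
// inputs, 2D FP16 weights and (optionally) a 2D byte pad mask, then calls
// fwd_cuda with a null mask pointer when masking is disabled.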
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs, torch::Tensor const &input_weights,
torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(
use_time_mask, is_training, heads, inputs, input_weights, output_weights,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr()) : nullptr,
dropout_prob);
}
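// Backward binding: same rank/dtype checks as the biased variant, minus the
// bias tensors; dispatches to bwd_cuda with the saved softmax and dropout
// state.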
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_results, inputs, input_weights,
output_weights, dropout_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self::rocblas_gemmex::fwd, "Self Multihead Attention Forward.");
m.def("backward", &multihead_attn::self::rocblas_gemmex::bwd, "Self Multihead Attention Backward.");
}
...@@ -11,10 +11,9 @@ ...@@ -11,10 +11,9 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include "dropout.h" #include "dropout.cuh"
#include "layer_norm.h" #include "softmax.cuh"
#include "softmax.h" #include "strided_batched_gemm.cuh"
#include "strided_batched_gemm.h"
namespace multihead_attn { namespace multihead_attn {
namespace self { namespace self {
...@@ -78,6 +77,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -78,6 +77,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd // Input Linear Fwd
...@@ -102,8 +103,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -102,8 +103,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
output_lin_dim, output_lin_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
...@@ -127,7 +128,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -127,7 +128,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
attn_batches, attn_batches,
flags); flags);
// Padded Softmax // Padded Softmax
bool softmax_success = false; bool softmax_success = false;
...@@ -181,7 +182,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -181,7 +182,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
head_dim*attn_batches, head_dim*attn_batches,
head_dim, head_dim,
attn_batches, attn_batches,
flags); flags);
// Output Linear // Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
...@@ -205,8 +206,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -205,8 +206,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
...@@ -267,6 +268,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -267,6 +268,8 @@ std::vector<torch::Tensor> bwd_cuda(
char a_layout_t{'t'}; char a_layout_t{'t'};
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
...@@ -301,8 +304,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -301,8 +304,8 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// Output Linear Wgrad // Output Linear Wgrad
...@@ -327,8 +330,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -327,8 +330,8 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// MatMul2 Dgrad1 // MatMul2 Dgrad1
...@@ -352,7 +355,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -352,7 +355,7 @@ std::vector<torch::Tensor> bwd_cuda(
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
attn_batches, attn_batches,
flags); flags);
// Matmul2 Dgrad2 // Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n, gemm_switch_fp32accum( a_layout_n,
...@@ -375,7 +378,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -375,7 +378,7 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim, lead_dim,
batch_stride, batch_stride,
attn_batches, attn_batches,
flags); flags);
// Apply Dropout Mask and Scale by Dropout Probability // Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>( apex_masked_scale_cuda<at::Half,float,uint32_t>(
...@@ -415,7 +418,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -415,7 +418,7 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim, lead_dim,
batch_stride, batch_stride,
attn_batches, attn_batches,
flags); flags);
// Matmul1 Dgrad2 // Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n, gemm_switch_fp32accum( a_layout_n,
...@@ -438,7 +441,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -438,7 +441,7 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim, lead_dim,
batch_stride, batch_stride,
attn_batches, attn_batches,
flags); flags);
// Input Linear Dgrad // Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
...@@ -462,8 +465,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -462,8 +465,8 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// Input Linear Wgrad // Input Linear Wgrad
...@@ -488,8 +491,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -488,8 +491,8 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
......
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace self_norm_add {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
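// Forward binding for the fused layer-norm + self-attention + residual-add
// variant: in addition to the usual input/weight checks, the layer-norm gamma
// and beta weights must be 1D FP16 tensors.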
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &pad_mask, float dropout_prob) {
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(
use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights,
lyr_nrm_beta_weights, input_weights, output_weights,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr()) : nullptr,
dropout_prob);
}
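// Backward binding: the saved layer-norm statistics (mean and inverse
// variance) must be 1D FP32 tensors, while both dropout masks (attention and
// residual-add) must be byte tensors; everything else is expected in FP16.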
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_mean.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_invvar.dim() == 1, "expected 1D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float,
"Only FLOAT is supported");
AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float,
"Only FLOAT is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_results, lyr_nrm_results,
lyr_nrm_mean, lyr_nrm_invvar, inputs, lyr_nrm_gamma_weights,
lyr_nrm_beta_weights, input_weights, output_weights,
dropout_mask, dropout_add_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
} // end namespace self_norm_add
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_norm_add::rocblas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def("backward", &multihead_attn::self_norm_add::rocblas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward.");
}