Commit 1436a66a authored by hubertlu-tw

Merge remote-tracking branch 'origin/master' into IFU-master-2021-10-15

parents aee9f00d 08e88b1b
...@@ -145,3 +145,6 @@ dmypy.json ...@@ -145,3 +145,6 @@ dmypy.json
# Cython debug symbols # Cython debug symbols
cython_debug/ cython_debug/
*.hip
*_hip.*
*hip*
...@@ -129,18 +129,18 @@ Note: Pytorch version recommended is >=1.5 for extension build. ...@@ -129,18 +129,18 @@ Note: Pytorch version recommended is >=1.5 for extension build.
### To install using python only build use the following command in apex folder: ### To install using python only build use the following command in apex folder:
``` ```
python3.6 setup.py install python setup.py install
``` ```
### To install using extensions enabled use the following command in apex folder: ### To install using extensions enabled use the following command in apex folder:
``` ```
python3.6 setup.py install --cpp_ext --cuda_ext python setup.py install --cpp_ext --cuda_ext
``` ```
### To install Apex on ROCm using ninja and without cloning the source ### To install Apex on ROCm using ninja and without cloning the source
``` ```
pip3.6 install ninja pip install ninja
pip3.6 install -v --install-option="--cpp_ext" --install-option="--cuda_ext" 'git+https://github.com/ROCmSoftwarePlatform/apex.git' pip install -v --install-option="--cpp_ext" --install-option="--cuda_ext" 'git+https://github.com/ROCmSoftwarePlatform/apex.git'
``` ```
### Linux ### Linux
......
...@@ -183,4 +183,4 @@ void ln_fwd_cuda( ...@@ -183,4 +183,4 @@ void ln_fwd_cuda(
assert(false && "Not implemented"); assert(false && "Not implemented");
} }
} }
\ No newline at end of file
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h> //#include <cuda_profiler_api.h>
#include "THC/THC.h" #include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
namespace multihead_attn { namespace multihead_attn {
namespace encdec { namespace encdec {
namespace cublas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -146,11 +146,11 @@ std::vector<torch::Tensor> bwd( ...@@ -146,11 +146,11 @@ std::vector<torch::Tensor> bwd(
); );
} }
} // end namespace cublas_gemmex } // end namespace rocblas_gemmex
} // end namespace encdec } // end namespace encdec
} // end namespace multihead_attn } // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::encdec::cublas_gemmex::fwd, "Encdec Multihead Attention Forward."); m.def("forward", &multihead_attn::encdec::rocblas_gemmex::fwd, "Encdec Multihead Attention Forward.");
m.def("backward", &multihead_attn::encdec::cublas_gemmex::bwd, "Encdec Multihead Attention Backward."); m.def("backward", &multihead_attn::encdec::rocblas_gemmex::bwd, "Encdec Multihead Attention Backward.");
} }
#include <vector> #include <vector>
#include <iostream> #include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h" #include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
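The new `#undef` lines above deserve a note: PyTorch's HIP extension build passes `-D__HIP_NO_HALF_OPERATORS__` and `-D__HIP_NO_HALF_CONVERSIONS__`, which strip the half-precision operator overloads and float/half conversions out of `hip_fp16.h`. Undefining the macros before the fp16 header is included restores them for this translation unit. A minimal sketch of the kind of code this enables (the `scale_half` helper is illustrative, not part of the commit):

```cpp
// Sketch only: the build defines __HIP_NO_HALF_OPERATORS__ and
// __HIP_NO_HALF_CONVERSIONS__ on the command line; undefining them before the
// fp16 header restores half arithmetic and float<->half conversions.
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
#include <cuda_fp16.h>   // hipify rewrites this to <hip/hip_fp16.h>

// Illustrative device helper (assumed, not from this file): relies on
// operator* for __half, which is only available with the macros undefined.
__device__ __half scale_half(__half x, float s) {
  return x * __float2half(s);
}
```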
...@@ -22,7 +25,7 @@ extern THCState *state; ...@@ -22,7 +25,7 @@ extern THCState *state;
namespace multihead_attn { namespace multihead_attn {
namespace encdec { namespace encdec {
namespace cublas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -86,9 +89,9 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -86,9 +89,9 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Q Fwd // Input Linear Q Fwd
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_q_dim, output_lin_q_dim,
...@@ -96,20 +99,25 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -96,20 +99,25 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()), static_cast<const void*>(input_weights_q.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(inputs_q.data_ptr()), static_cast<const void*>(inputs_q.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
q_lin_results_ptr, q_lin_results_ptr,
CUDA_R_16F, rocblas_datatype_f16_r,
output_lin_q_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_q_dim, output_lin_q_dim,
CUDA_R_32F, rocblas_datatype_f32_r,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// Input Linear KV Fwd // Input Linear KV Fwd
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_kv_dim, output_lin_kv_dim,
...@@ -117,17 +125,22 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -117,17 +125,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()), static_cast<const void*>(input_weights_kv.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()), static_cast<const void*>(inputs_kv.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
k_lin_results_ptr, k_lin_results_ptr,
CUDA_R_16F, rocblas_datatype_f16_r,
output_lin_kv_dim, output_lin_kv_dim,
CUDA_R_32F, k_lin_results_ptr,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); rocblas_datatype_f16_r,
output_lin_kv_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
...@@ -146,6 +159,9 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -146,6 +159,9 @@ std::vector<torch::Tensor> fwd_cuda(
beta, beta,
static_cast<half*>(softmax_results_ptr), static_cast<half*>(softmax_results_ptr),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
attn_batches); attn_batches);
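The `gemm_switch_fp32accum` call sites above (and all that follow) now pass the output buffer, its leading dimension, and its batch stride a second time. That mirrors rocBLAS's strided-batched interface, which, unlike `cublasGemmStridedBatchedEx`, takes a separate output matrix D in addition to C; the port simply aliases D to C. A sketch of the underlying call, assuming `gemm_switch_fp32accum` forwards to `rocblas_gemm_strided_batched_ex` with fp16 storage and fp32 accumulation (the helper name and defaults are assumptions, not code from this commit):

```cpp
#include <rocblas.h>   // <rocblas/rocblas.h> on newer ROCm releases

// Hypothetical core of an fp32-accumulating strided-batched GEMM on ROCm.
// It shows why each output buffer now appears twice at the call sites.
inline rocblas_status strided_batched_gemm_f16_f32acc(
    rocblas_handle handle,
    rocblas_operation transA, rocblas_operation transB,
    int m, int n, int k,
    const float* alpha,
    const void* A, int lda, long long strideA,
    const void* B, int ldb, long long strideB,
    const float* beta,
    void* C, int ldc, long long strideC,
    int batch_count) {
  return rocblas_gemm_strided_batched_ex(
      handle, transA, transB, m, n, k,
      alpha,
      A, rocblas_datatype_f16_r, lda, strideA,
      B, rocblas_datatype_f16_r, ldb, strideB,
      beta,
      C, rocblas_datatype_f16_r, ldc, strideC,   // C: input/accumulator
      C, rocblas_datatype_f16_r, ldc, strideC,   // D: output, aliased to C
      batch_count,
      rocblas_datatype_f32_r,                    // accumulate in fp32
      rocblas_gemm_algo_standard,
      0 /*solution_index*/, 0 /*flags*/);
}
```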
...@@ -208,10 +224,13 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -208,10 +224,13 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(matmul2_results.data_ptr()), static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches, head_dim*attn_batches,
head_dim, head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches); attn_batches);
// Output Linear // Output Linear
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -219,20 +238,22 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -219,20 +238,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(outputs.data_ptr()), static_cast<void*>(outputs.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(outputs.data_ptr()),
//CUBLAS_GEMM_ALGO1_TENSOR_OP)); rocblas_datatype_f16_r,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); embed_dim,
rocblas_datatype_f32_r,
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); algo,
solution_index,
flags));
return { return {
input_lin_q_results, input_lin_q_results,
...@@ -312,10 +333,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -312,10 +333,8 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -323,20 +342,25 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -323,20 +342,25 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(output_grads.data_ptr()), static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()), static_cast<void*>(output_lin_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, rocblas_datatype_f32_r,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -344,17 +368,22 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -344,17 +368,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q, batches_q,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(output_grads.data_ptr()), static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()), static_cast<void*>(output_weight_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, rocblas_datatype_f32_r,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// MatMul2 Dgrad1 // MatMul2 Dgrad1
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
...@@ -374,6 +403,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -374,6 +403,9 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches); attn_batches);
// Matmul2 Dgrad2 // Matmul2 Dgrad2
...@@ -394,6 +426,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -394,6 +426,9 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr, v_lin_grads_ptr,
lead_dim_kv, lead_dim_kv,
batch_stride_kv, batch_stride_kv,
v_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches); attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability // Apply Dropout Mask and Scale by Dropout Probability
...@@ -433,6 +468,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -433,6 +468,9 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr, q_lin_grads_ptr,
lead_dim_q, lead_dim_q,
batch_stride_q, batch_stride_q,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
attn_batches); attn_batches);
// Matmul1 Dgrad2 // Matmul1 Dgrad2
...@@ -453,10 +491,13 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -453,10 +491,13 @@ std::vector<torch::Tensor> bwd_cuda(
k_lin_grads_ptr, k_lin_grads_ptr,
lead_dim_kv, lead_dim_kv,
batch_stride_kv, batch_stride_kv,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches); attn_batches);
// Input Linear Q Dgrad // Input Linear Q Dgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -464,21 +505,25 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -464,21 +505,25 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_q_dim, output_lin_q_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()), static_cast<const void*>(input_weights_q.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F, rocblas_datatype_f16_r,
output_lin_q_dim, output_lin_q_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_q_grads.data_ptr()), static_cast<void*>(input_q_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(input_q_grads.data_ptr()),
//CUBLAS_GEMM_ALGO10_TENSOR_OP)); rocblas_datatype_f16_r,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); embed_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// Input Linear Q Wgrad // Input Linear Q Wgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -486,20 +531,25 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -486,20 +531,25 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q, batches_q,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_q.data_ptr()), static_cast<const void*>(inputs_q.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F, rocblas_datatype_f16_r,
output_lin_q_dim, output_lin_q_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_q_grads.data_ptr()), static_cast<void*>(input_weight_q_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, rocblas_datatype_f32_r,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// Input Linear KV Dgrad // Input Linear KV Dgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -507,21 +557,25 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -507,21 +557,25 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_kv_dim, output_lin_kv_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()), static_cast<const void*>(input_weights_kv.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(k_lin_grads_ptr), static_cast<const void*>(k_lin_grads_ptr),
CUDA_R_16F, rocblas_datatype_f16_r,
output_lin_kv_dim, output_lin_kv_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_kv_grads.data_ptr()), static_cast<void*>(input_kv_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, rocblas_datatype_f32_r,
//CUBLAS_GEMM_ALGO10_TENSOR_OP)); algo,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); solution_index,
flags));
// Input Linear KV Wgrad // Input Linear KV Wgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -529,18 +583,22 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -529,18 +583,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches_kv, batches_kv,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_kv.data_ptr()), static_cast<const void*>(inputs_kv.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(k_lin_grads_ptr), static_cast<const void*>(k_lin_grads_ptr),
CUDA_R_16F, rocblas_datatype_f16_r,
output_lin_kv_dim, output_lin_kv_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()), static_cast<void*>(input_weight_kv_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, rocblas_datatype_f32_r,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); solution_index,
flags));
return { return {
input_q_grads, input_q_grads,
...@@ -551,6 +609,6 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -551,6 +609,6 @@ std::vector<torch::Tensor> bwd_cuda(
}; };
} }
} // end namespace cublas_gemmex } // end namespace rocblas_gemmex
} // end namespace encdec } // end namespace encdec
} // end namespace multihead_attn } // end namespace multihead_attn
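The file above shows the pattern repeated across every extension in this merge: each `cublasGemmEx` call becomes `rocblas_gemm_ex`, `CUDA_R_16F`/`CUDA_R_32F` become `rocblas_datatype_f16_r`/`rocblas_datatype_f32_r`, the explicit `cublasSetMathMode` calls are dropped, and the trailing `algo`, `solution_index`, and `flags` selectors replace `CUBLAS_GEMM_DEFAULT_TENSOR_OP`. The only structural difference is that rocBLAS takes a separate output matrix D, which these kernels alias to C. A condensed sketch of the mapping, assuming the default standard algorithm (the wrapper name and argument choices are illustrative, not part of the commit):

```cpp
#include <rocblas.h>   // <rocblas/rocblas.h> on newer ROCm releases

// Hypothetical helper mirroring how the hunks above call rocblas_gemm_ex:
// fp16 inputs/outputs, fp32 accumulation, C reused as the output matrix D.
inline rocblas_status gemm_ex_f16_f32acc(rocblas_handle handle,
                                         rocblas_operation transA,
                                         rocblas_operation transB,
                                         int m, int n, int k,
                                         const float* alpha,
                                         const void* A, int lda,
                                         const void* B, int ldb,
                                         const float* beta,
                                         void* C, int ldc) {
  return rocblas_gemm_ex(handle, transA, transB, m, n, k,
                         alpha,
                         A, rocblas_datatype_f16_r, lda,
                         B, rocblas_datatype_f16_r, ldb,
                         beta,
                         C, rocblas_datatype_f16_r, ldc,   // C (input)
                         C, rocblas_datatype_f16_r, ldc,   // D (output), aliased to C
                         rocblas_datatype_f32_r,           // compute type
                         rocblas_gemm_algo_standard,       // algo
                         0,                                // solution_index
                         0);                               // flags
}
```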
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
namespace multihead_attn { namespace multihead_attn {
namespace encdec_norm_add { namespace encdec_norm_add {
namespace cublas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -192,7 +192,7 @@ std::vector<torch::Tensor> bwd( ...@@ -192,7 +192,7 @@ std::vector<torch::Tensor> bwd(
} // end namespace multihead_attn } // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::encdec_norm_add::cublas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward."); m.def("forward", &multihead_attn::encdec_norm_add::rocblas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def("backward", &multihead_attn::encdec_norm_add::cublas_gemmex::bwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward."); m.def("backward", &multihead_attn::encdec_norm_add::rocblas_gemmex::bwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.");
} }
#include <vector> #include <vector>
#include <iostream> #include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h" #include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
...@@ -21,7 +25,7 @@ extern THCState *state; ...@@ -21,7 +25,7 @@ extern THCState *state;
namespace multihead_attn { namespace multihead_attn {
namespace encdec_norm_add { namespace encdec_norm_add {
namespace cublas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -95,7 +99,6 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -95,7 +99,6 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm // Layer Norm
HostApplyLayerNorm<at::Half,float>( HostApplyLayerNorm<at::Half,float>(
static_cast<at::Half*>(lyr_nrm_results.data_ptr()), static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
...@@ -109,7 +112,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -109,7 +112,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr())); static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Q Fwd // Input Linear Q Fwd
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_q_dim, output_lin_q_dim,
...@@ -117,21 +120,26 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -117,21 +120,26 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()), static_cast<const void*>(input_weights_q.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
//static_cast<const void*>(inputs_q.data_ptr()), //static_cast<const void*>(inputs_q.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()), static_cast<const void*>(lyr_nrm_results.data_ptr()),
CUDA_R_16F, b_type,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
q_lin_results_ptr, q_lin_results_ptr,
CUDA_R_16F, c_type,
output_lin_q_dim, output_lin_q_dim,
CUDA_R_32F, q_lin_results_ptr,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); d_type,
output_lin_q_dim,
compute_type,
algo,
solution_index,
flags));
// Input Linear KV Fwd // Input Linear KV Fwd
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_kv_dim, output_lin_kv_dim,
...@@ -139,18 +147,22 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -139,18 +147,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()), static_cast<const void*>(input_weights_kv.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()), static_cast<const void*>(inputs_kv.data_ptr()),
CUDA_R_16F, b_type,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
k_lin_results_ptr, k_lin_results_ptr,
CUDA_R_16F, c_type,
output_lin_kv_dim, output_lin_kv_dim,
CUDA_R_32F, k_lin_results_ptr,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); d_type,
output_lin_kv_dim,
compute_type,
algo,
solution_index,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
a_layout_t, a_layout_t,
...@@ -168,7 +180,10 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -168,7 +180,10 @@ std::vector<torch::Tensor> fwd_cuda(
beta, beta,
static_cast<half*>(softmax_results_ptr), static_cast<half*>(softmax_results_ptr),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches); attn_batches);
// Padded Softmax // Padded Softmax
...@@ -230,11 +245,14 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -230,11 +245,14 @@ std::vector<torch::Tensor> fwd_cuda(
beta, beta,
static_cast<half*>(matmul2_results.data_ptr()), static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches, head_dim*attn_batches,
head_dim, head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches); attn_batches);
// Output Linear // Output Linear
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -242,19 +260,23 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -242,19 +260,23 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F, b_type,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()), static_cast<void*>(output_lin_results.data_ptr()),
CUDA_R_16F, c_type,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(output_lin_results.data_ptr()),
//CUBLAS_GEMM_ALGO1_TENSOR_OP)); d_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); embed_dim,
compute_type,
algo,
solution_index,
flags));
// End-of-block Dropout-Add // End-of-block Dropout-Add
if (is_training) { if (is_training) {
apex_dropout_add_cuda<at::Half,float,uint32_t>( apex_dropout_add_cuda<at::Half,float,uint32_t>(
...@@ -272,8 +294,6 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -272,8 +294,6 @@ std::vector<torch::Tensor> fwd_cuda(
total_tokens_q); total_tokens_q);
} }
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
lyr_nrm_results, lyr_nrm_results,
lyr_nrm_mean, lyr_nrm_mean,
...@@ -366,9 +386,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -366,9 +386,7 @@ std::vector<torch::Tensor> bwd_cuda(
char a_layout_t{'t'}; char a_layout_t{'t'};
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Dropout Add Backward // Dropout Add Backward
apex_masked_scale_cuda<at::Half,float,uint32_t>( apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_grads.data_ptr()), static_cast<at::Half const*>(output_grads.data_ptr()),
...@@ -378,7 +396,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -378,7 +396,7 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob))); (1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -386,20 +404,25 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -386,20 +404,25 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()), static_cast<const void*>(dropout_add_grads.data_ptr()),
CUDA_R_16F, b_type,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()), static_cast<void*>(output_lin_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(output_lin_grads.data_ptr()),
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); d_type,
embed_dim,
compute_type,
algo,
solution_index,
flags));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -407,17 +430,22 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -407,17 +430,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q, batches_q,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()), static_cast<const void*>(dropout_add_grads.data_ptr()),
CUDA_R_16F, b_type,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()), static_cast<void*>(output_weight_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(output_weight_grads.data_ptr()),
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); d_type,
embed_dim,
compute_type,
algo,
solution_index,
flags));
// MatMul2 Dgrad1 // MatMul2 Dgrad1
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
...@@ -437,6 +465,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -437,6 +465,9 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches); attn_batches);
// Matmul2 Dgrad2 // Matmul2 Dgrad2
...@@ -457,6 +488,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -457,6 +488,9 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr, v_lin_grads_ptr,
lead_dim_kv, lead_dim_kv,
batch_stride_kv, batch_stride_kv,
v_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches); attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability // Apply Dropout Mask and Scale by Dropout Probability
...@@ -496,6 +530,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -496,6 +530,9 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr, q_lin_grads_ptr,
lead_dim_q, lead_dim_q,
batch_stride_q, batch_stride_q,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
attn_batches); attn_batches);
// Matmul1 Dgrad2 // Matmul1 Dgrad2
...@@ -515,11 +552,14 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -515,11 +552,14 @@ std::vector<torch::Tensor> bwd_cuda(
beta, beta,
k_lin_grads_ptr, k_lin_grads_ptr,
lead_dim_kv, lead_dim_kv,
batch_stride_kv, batch_stride_kv,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches); attn_batches);
// Input Linear Q Dgrad // Input Linear Q Dgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -527,22 +567,26 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -527,22 +567,26 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_q_dim, output_lin_q_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()), static_cast<const void*>(input_weights_q.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F, b_type,
output_lin_q_dim, output_lin_q_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
//static_cast<void*>(input_q_grads.data_ptr()), //static_cast<void*>(input_q_grads.data_ptr()),
static_cast<void*>(input_lin_q_grads.data_ptr()), static_cast<void*>(input_lin_q_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(input_lin_q_grads.data_ptr()),
//CUBLAS_GEMM_ALGO10_TENSOR_OP)); d_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); embed_dim,
compute_type,
algo,
solution_index,
flags));
// Input Linear Q Wgrad // Input Linear Q Wgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -550,20 +594,25 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -550,20 +594,25 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q, batches_q,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_q.data_ptr()), static_cast<const void*>(inputs_q.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F, b_type,
output_lin_q_dim, output_lin_q_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_q_grads.data_ptr()), static_cast<void*>(input_weight_q_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(input_weight_q_grads.data_ptr()),
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); d_type,
embed_dim,
compute_type,
algo,
solution_index,
flags));
// Input Linear KV Dgrad // Input Linear KV Dgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -571,21 +620,25 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -571,21 +620,25 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_kv_dim, output_lin_kv_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()), static_cast<const void*>(input_weights_kv.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(k_lin_grads_ptr), static_cast<const void*>(k_lin_grads_ptr),
CUDA_R_16F, b_type,
output_lin_kv_dim, output_lin_kv_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_kv_grads.data_ptr()), static_cast<void*>(input_kv_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(input_kv_grads.data_ptr()),
//CUBLAS_GEMM_ALGO10_TENSOR_OP)); d_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); embed_dim,
compute_type,
algo,
solution_index,
flags));
// Input Linear KV Wgrad // Input Linear KV Wgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -593,17 +646,22 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -593,17 +646,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches_kv, batches_kv,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_kv.data_ptr()), static_cast<const void*>(inputs_kv.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(k_lin_grads_ptr), static_cast<const void*>(k_lin_grads_ptr),
CUDA_R_16F, b_type,
output_lin_kv_dim, output_lin_kv_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()), static_cast<void*>(input_weight_kv_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(input_weight_kv_grads.data_ptr()),
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); d_type,
embed_dim,
compute_type,
algo,
solution_index,
flags));
// Fused Layer Norm Bwd with Residual Add // Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half,float>( HostLayerNormGradient<half,float>(
...@@ -622,7 +680,6 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -622,7 +680,6 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(lyr_nrm_beta_grads.data_ptr()) static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
); );
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_q_grads, input_q_grads,
...@@ -635,6 +692,6 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -635,6 +692,6 @@ std::vector<torch::Tensor> bwd_cuda(
}; };
} }
} // end namespace cublas_gemmex } // end namespace rocblas_gemmex
} // end namespace encdec_norm_add } // end namespace encdec_norm_add
} // end namespace multihead_attn } // end namespace multihead_attn
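Unlike the previous file, the `rocblas_gemm_ex` calls in this one use named `a_type`, `b_type`, `c_type`, `d_type`, `compute_type`, `algo`, `solution_index`, and `flags` variables whose declarations fall in an elided part of the diff. Presumably they encode the same fp16-storage / fp32-accumulation setup spelled out literally elsewhere; a sketch of plausible definitions (assumed, not shown in these hunks):

```cpp
#include <rocblas.h>   // <rocblas/rocblas.h> on newer ROCm releases

// Assumed declarations for the identifiers used by the rocblas_gemm_ex calls above.
rocblas_datatype a_type       = rocblas_datatype_f16_r;   // A is fp16
rocblas_datatype b_type       = rocblas_datatype_f16_r;   // B is fp16
rocblas_datatype c_type       = rocblas_datatype_f16_r;   // C is fp16
rocblas_datatype d_type       = rocblas_datatype_f16_r;   // D (output) is fp16
rocblas_datatype compute_type = rocblas_datatype_f32_r;   // accumulate in fp32
rocblas_gemm_algo algo        = rocblas_gemm_algo_standard;
int32_t  solution_index = 0;   // 0 lets rocBLAS pick for the standard algo
uint32_t flags          = 0;   // no special flags
```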
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
template<typename U> __device__ template<typename U> __device__
void cuWelfordOnlineSum( void cuWelfordOnlineSum(
const U curr, const U curr,
...@@ -84,9 +85,9 @@ void cuWelfordMuSigma2( ...@@ -84,9 +85,9 @@ void cuWelfordMuSigma2(
// intra-warp reductions // intra-warp reductions
for (int l = 0; l <= 4; ++l) { for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31; int srcLaneB = (threadIdx.x+(1<<l))&31;
U muB = WARP_SHFL(mu, srcLaneB); U muB = WARP_SHFL(mu, srcLaneB, 32);
U countB = WARP_SHFL(count, srcLaneB); U countB = WARP_SHFL(count, srcLaneB, 32);
U sigma2B = WARP_SHFL(sigma2, srcLaneB); U sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count); cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
} }
// threadIdx.x == 0 has correct values for each warp // threadIdx.x == 0 has correct values for each warp
...@@ -122,8 +123,8 @@ void cuWelfordMuSigma2( ...@@ -122,8 +123,8 @@ void cuWelfordMuSigma2(
sigma2 = ubuf[1]/U(n2); sigma2 = ubuf[1]/U(n2);
// don't care about final value of count, we know count == n2 // don't care about final value of count, we know count == n2
} else { } else {
mu = WARP_SHFL(mu, 0); mu = WARP_SHFL(mu, 0, 32);
sigma2 = WARP_SHFL(sigma2/U(n2), 0); sigma2 = WARP_SHFL(sigma2/U(n2), 0, 32);
} }
} }
} }
...@@ -180,9 +181,9 @@ void cuWelfordMuSigma2( ...@@ -180,9 +181,9 @@ void cuWelfordMuSigma2(
// intra-warp reductions // intra-warp reductions
for (int l = 0; l <= 4; ++l) { for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31; int srcLaneB = (threadIdx.x+(1<<l))&31;
float muB = WARP_SHFL(mu, srcLaneB); float muB = WARP_SHFL(mu, srcLaneB, 32);
float countB = WARP_SHFL(count, srcLaneB); float countB = WARP_SHFL(count, srcLaneB, 32);
float sigma2B = WARP_SHFL(sigma2, srcLaneB); float sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
} }
// threadIdx.x == 0 has correct values for each warp // threadIdx.x == 0 has correct values for each warp
...@@ -218,8 +219,8 @@ void cuWelfordMuSigma2( ...@@ -218,8 +219,8 @@ void cuWelfordMuSigma2(
sigma2 = ubuf[1]/float(n2); sigma2 = ubuf[1]/float(n2);
// don't care about final value of count, we know count == n2 // don't care about final value of count, we know count == n2
} else { } else {
mu = WARP_SHFL(mu, 0); mu = WARP_SHFL(mu, 0, 32);
sigma2 = WARP_SHFL(sigma2/float(n2), 0); sigma2 = WARP_SHFL(sigma2/float(n2), 0, 32);
} }
} }
} }
...@@ -227,9 +228,19 @@ void cuWelfordMuSigma2( ...@@ -227,9 +228,19 @@ void cuWelfordMuSigma2(
template<typename U> U rsqrt(U v) { template<typename U> U rsqrt(U v) {
return U(1) / sqrt(v); return U(1) / sqrt(v);
} }
//template<> float rsqrt(float v) {
// return rsqrtf(v);
//}
#if defined __HIP_PLATFORM_HCC__
__device__ float rsqrt(float v) {
return rsqrtf(v);
}
#else
template<> float rsqrt(float v) { template<> float rsqrt(float v) {
return rsqrtf(v); return rsqrtf(v);
} }
#endif
template<> double rsqrt(double v) { template<> double rsqrt(double v) {
return rsqrt(v); return rsqrt(v);
} }
...@@ -290,7 +301,7 @@ void cuApplyLayerNorm( ...@@ -290,7 +301,7 @@ void cuApplyLayerNorm(
// 1) blockDim.x == warpSize // 1) blockDim.x == warpSize
// 2) Tensors are contiguous // 2) Tensors are contiguous
// //
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared; SharedMemory<U> shared;
U* buf = shared.getPointer(); U* buf = shared.getPointer();
U mu,sigma2; U mu,sigma2;
...@@ -529,7 +540,7 @@ void cuComputeGradInput( ...@@ -529,7 +540,7 @@ void cuComputeGradInput(
const T* gamma, const T* gamma,
T* grad_input) T* grad_input)
{ {
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0); U sum_loss1 = U(0);
U sum_loss2 = U(0); U sum_loss2 = U(0);
const U c_mean = mean[i1]; const U c_mean = mean[i1];
...@@ -574,8 +585,8 @@ void cuComputeGradInput( ...@@ -574,8 +585,8 @@ void cuComputeGradInput(
} }
// intra-warp reductions // intra-warp reductions
for (int mask = blockDim.x/2; mask > 0; mask /= 2) { for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32);
} }
// inter-warp reductions // inter-warp reductions
if (blockDim.y > 1) { if (blockDim.y > 1) {
......
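The layer-norm changes above add an explicit width of 32 to every `WARP_SHFL` / `WARP_SHFL_XOR` call so the Welford reductions keep shuffling across 32-lane groups on AMD hardware, where a wavefront is 64 lanes wide and the shuffle intrinsics default to the full wavefront. A sketch of compatibility macros in that spirit (the real macros live in apex's shared headers; this version is an assumption, not the file's code):

```cpp
// Portable warp-shuffle macros with an explicit width argument.
#ifndef WARP_SHFL
#if defined(__HIP_PLATFORM_HCC__)
// HIP: no *_sync variants; width defaults to warpSize (64 on most AMD GPUs),
// so an explicit width is needed to emulate 32-lane CUDA warps.
#define WARP_SHFL(val, src, width)      __shfl((val), (src), (width))
#define WARP_SHFL_XOR(val, mask, width) __shfl_xor((val), (mask), (width))
#else
// CUDA: full-mask sync shuffles; width is passed explicitly so both back
// ends take the same argument list.
#define WARP_SHFL(val, src, width)      __shfl_sync(0xffffffff, (val), (src), (width))
#define WARP_SHFL_XOR(val, mask, width) __shfl_xor_sync(0xffffffff, (val), (mask), (width))
#endif
#endif
```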
#include <vector> #include <vector>
#include <iostream> #include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h" #include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
namespace multihead_attn { namespace multihead_attn {
namespace self_bias_additive_mask { namespace self_bias_additive_mask {
namespace cublas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -132,12 +132,12 @@ std::vector<torch::Tensor> bwd( ...@@ -132,12 +132,12 @@ std::vector<torch::Tensor> bwd(
); );
} }
} // end namespace cublas_gemmex } // end namespace rocblas_gemmex
} // end namespace self } // end namespace self
} // end namespace multihead_attn } // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward."); m.def("forward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward."); m.def("backward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
} }
#include <vector> #include <vector>
#include <math.h>
#include <iostream> #include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#include <ATen/ATen.h> #undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h> //#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h" #include "strided_batched_gemm.h"
#include "softmax.h" #include "softmax.h"
...@@ -21,7 +24,7 @@ extern THCState *state; ...@@ -21,7 +24,7 @@ extern THCState *state;
namespace multihead_attn { namespace multihead_attn {
namespace self_bias_additive_mask { namespace self_bias_additive_mask {
namespace cublas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -48,8 +51,8 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -48,8 +51,8 @@ std::vector<torch::Tensor> fwd_cuda(
const int batch_stride = 3 * head_dim; const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len; const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0; const float alpha = 1.0;
const float beta_zero = 0.0; const float beta_zero = 0.0;
const float beta_one = 1.0; const float beta_one = 1.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim)); const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// There is no reason to use more than one stream as every kernel is // There is no reason to use more than one stream as every kernel is
...@@ -82,10 +85,9 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -82,10 +85,9 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd // Input Linear Fwd
input_lin_results.copy_(input_biases); input_lin_results.copy_(input_biases);
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_dim, output_lin_dim,
...@@ -93,18 +95,23 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -93,18 +95,23 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(inputs.data_ptr()), static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(&beta_one), static_cast<const void*>(&beta_one),
q_lin_results_ptr, q_lin_results_ptr,
CUDA_R_16F, rocblas_datatype_f16_r,
output_lin_dim, output_lin_dim,
CUDA_R_32F, q_lin_results_ptr,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
a_layout_t, a_layout_t,
...@@ -123,7 +130,11 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -123,7 +130,11 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(bmm1_results_ptr), static_cast<half*>(bmm1_results_ptr),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches); attn_batches);
// Padded Softmax // Padded Softmax
bool softmax_success = false; bool softmax_success = false;
if (is_training) { if (is_training) {
...@@ -168,12 +179,15 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -168,12 +179,15 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(matmul2_results.data_ptr()), static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches, head_dim*attn_batches,
head_dim, head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches); attn_batches);
outputs.copy_(output_biases); outputs.copy_(output_biases);
// Output Linear // Output Linear
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -181,20 +195,22 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -181,20 +195,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(&beta_one), static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()), static_cast<void*>(outputs.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(outputs.data_ptr()),
//CUBLAS_GEMM_ALGO1_TENSOR_OP)); rocblas_datatype_f16_r,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); embed_dim,
rocblas_datatype_f32_r,
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); algo,
solution_index,
flags));
return { return {
input_lin_results, input_lin_results,
...@@ -264,10 +280,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -264,10 +280,8 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -275,19 +289,25 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -275,19 +289,25 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(output_grads.data_ptr()), static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()), static_cast<void*>(output_lin_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(output_lin_grads.data_ptr()),
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -295,17 +315,22 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -295,17 +315,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches, batches,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(output_grads.data_ptr()), static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()), static_cast<void*>(output_weight_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(output_weight_grads.data_ptr()),
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false); auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1 // MatMul2 Dgrad1
...@@ -326,8 +351,11 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -326,8 +351,11 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches); attn_batches);
// Matmul2 Dgrad2 // Matmul2 Dgrad2
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
a_layout_n, a_layout_n,
...@@ -346,6 +374,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -346,6 +374,9 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr, v_lin_grads_ptr,
lead_dim, lead_dim,
batch_stride, batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches); attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability // Apply Dropout Mask and Scale by Dropout Probability
...@@ -362,7 +393,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -362,7 +393,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches*q_seq_len/sequences, attn_batches*q_seq_len/sequences,
attn_batches*q_seq_len, attn_batches*q_seq_len,
stream); stream);
// Matmul1 Dgrad1 // Matmul1 Dgrad1
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
a_layout_n, a_layout_n,
...@@ -381,8 +412,11 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -381,8 +412,11 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr, q_lin_grads_ptr,
lead_dim, lead_dim,
batch_stride, batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches); attn_batches);
// Matmul1 Dgrad2 // Matmul1 Dgrad2
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
a_layout_n, a_layout_n,
...@@ -401,9 +435,13 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -401,9 +435,13 @@ std::vector<torch::Tensor> bwd_cuda(
k_lin_grads_ptr, k_lin_grads_ptr,
lead_dim, lead_dim,
batch_stride, batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches); attn_batches);
// Input Linear Dgrad // Input Linear Dgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -411,22 +449,25 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -411,22 +449,25 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim, output_lin_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()), static_cast<const void*>(input_lin_output_grads.data_ptr()),
//static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r,
CUDA_R_16F,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()), static_cast<void*>(input_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(input_grads.data_ptr()),
//CUBLAS_GEMM_ALGO10_TENSOR_OP)); rocblas_datatype_f16_r,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); embed_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// Input Linear Wgrad // Input Linear Wgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -434,20 +475,24 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -434,20 +475,24 @@ std::vector<torch::Tensor> bwd_cuda(
batches, batches,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()), static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F, rocblas_datatype_f16_r,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()), static_cast<void*>(input_weight_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, rocblas_datatype_f32_r,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_grads, input_grads,
...@@ -458,6 +503,6 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -458,6 +503,6 @@ std::vector<torch::Tensor> bwd_cuda(
}; };
} }
} // end namespace cublas_gemmex } // end namespace rocblas_gemmex
} // end namespace self } // end namespace self
} // end namespace multihead_attn } // end namespace multihead_attn
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
namespace multihead_attn { namespace multihead_attn {
namespace self_bias { namespace self_bias {
namespace cublas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -128,12 +128,12 @@ std::vector<torch::Tensor> bwd( ...@@ -128,12 +128,12 @@ std::vector<torch::Tensor> bwd(
); );
} }
} // end namespace cublas_gemmex } // end namespace rocblas_gemmex
} // end namespace self } // end namespace self
} // end namespace multihead_attn } // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward."); m.def("forward", &multihead_attn::self_bias::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward."); m.def("backward", &multihead_attn::self_bias::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
} }
#include <vector> #include <vector>
#include <iostream> #include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h> //#include <cuda_profiler_api.h>
#include "THC/THC.h" #include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
...@@ -21,7 +24,7 @@ extern THCState *state; ...@@ -21,7 +24,7 @@ extern THCState *state;
namespace multihead_attn { namespace multihead_attn {
namespace self_bias { namespace self_bias {
namespace cublas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -80,11 +83,10 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -80,11 +83,10 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_t{'t'}; char a_layout_t{'t'};
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd // Input Linear Fwd
input_lin_results.copy_(input_biases); input_lin_results.copy_(input_biases);
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_dim, output_lin_dim,
...@@ -92,17 +94,22 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -92,17 +94,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(inputs.data_ptr()), static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(&beta_one), static_cast<const void*>(&beta_one),
q_lin_results_ptr, q_lin_results_ptr,
CUDA_R_16F, rocblas_datatype_f16_r,
output_lin_dim, output_lin_dim,
CUDA_R_32F, q_lin_results_ptr,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
...@@ -122,7 +129,11 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -122,7 +129,11 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(softmax_results_ptr), static_cast<half*>(softmax_results_ptr),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches); attn_batches);
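The gemm_switch_fp32accum call sites now pass a second pointer/leading-dimension/stride triple before attn_batches. This mirrors rocblas_gemm_strided_batched_ex, which, unlike cublasGemmStridedBatchedEx, distinguishes the read matrix C from the write matrix D; every call here aliases the two so the batched GEMM stays in place. A hypothetical forwarding helper, shown only to illustrate the argument mapping (the name and signature are illustrative, not the repository's actual helper):

```
#include <rocblas.h>

// Sketch: forward a batched fp16 GEMM with fp32 accumulation to rocBLAS.
static rocblas_status batched_hgemm_fp32accum(
    rocblas_handle handle,
    rocblas_operation transa, rocblas_operation transb,
    int m, int n, int k,
    const float* alpha,
    const void* a, int lda, long long stride_a,
    const void* b, int ldb, long long stride_b,
    const float* beta,
    void* c, int ldc, long long stride_c,
    int batch_count) {
  // rocblas_gemm_strided_batched_ex splits the output into C (read) and
  // D (write); passing the same buffer for both reproduces the in-place
  // behaviour of the original cuBLAS call sites.
  return rocblas_gemm_strided_batched_ex(
      handle, transa, transb, m, n, k,
      alpha,
      a, rocblas_datatype_f16_r, lda, stride_a,
      b, rocblas_datatype_f16_r, ldb, stride_b,
      beta,
      c, rocblas_datatype_f16_r, ldc, stride_c,
      c, rocblas_datatype_f16_r, ldc, stride_c,  // D aliased to C
      batch_count,
      rocblas_datatype_f32_r,                    // accumulate in fp32
      rocblas_gemm_algo_standard, 0, 0);
}
```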
// Padded Softmax // Padded Softmax
bool softmax_success = false; bool softmax_success = false;
if (pad_mask == nullptr) { if (pad_mask == nullptr) {
...@@ -180,12 +191,15 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -180,12 +191,15 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(matmul2_results.data_ptr()), static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches, head_dim*attn_batches,
head_dim, head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches); attn_batches);
outputs.copy_(output_biases); outputs.copy_(output_biases);
// Output Linear // Output Linear
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -193,20 +207,22 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -193,20 +207,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(&beta_one), static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()), static_cast<void*>(outputs.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(outputs.data_ptr()),
//CUBLAS_GEMM_ALGO1_TENSOR_OP)); rocblas_datatype_f16_r,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); embed_dim,
rocblas_datatype_f32_r,
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); algo,
solution_index,
flags));
return { return {
input_lin_results, input_lin_results,
...@@ -275,10 +291,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -275,10 +291,8 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -286,19 +300,25 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -286,19 +300,25 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(output_grads.data_ptr()), static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()), static_cast<void*>(output_lin_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, rocblas_datatype_f32_r,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -306,17 +326,22 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -306,17 +326,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches, batches,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(output_grads.data_ptr()), static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()), static_cast<void*>(output_weight_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, rocblas_datatype_f32_r,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false); auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
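For reference, the bias gradient on the line above is just the output gradient summed over all tokens: flattening to a (tokens × embed_dim) view and reducing over dim 0 computes, for each output feature j (standard result for a linear layer y_t = W x_t + b),

```
\frac{\partial \mathcal{L}}{\partial b_j}
  \;=\; \sum_{t=1}^{T} \frac{\partial \mathcal{L}}{\partial y_{t,j}},
\qquad y_t = W x_t + b .
```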
// MatMul2 Dgrad1 // MatMul2 Dgrad1
...@@ -337,6 +362,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -337,6 +362,9 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches); attn_batches);
// Matmul2 Dgrad2 // Matmul2 Dgrad2
...@@ -357,6 +385,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -357,6 +385,9 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr, v_lin_grads_ptr,
lead_dim, lead_dim,
batch_stride, batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches); attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability // Apply Dropout Mask and Scale by Dropout Probability
...@@ -385,7 +416,10 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -385,7 +416,10 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
beta, beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr, q_lin_grads_ptr,
lead_dim, lead_dim,
batch_stride, batch_stride,
...@@ -408,10 +442,13 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -408,10 +442,13 @@ std::vector<torch::Tensor> bwd_cuda(
beta, beta,
k_lin_grads_ptr, k_lin_grads_ptr,
lead_dim, lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride, batch_stride,
attn_batches); attn_batches);
// Input Linear Dgrad // Input Linear Dgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -419,22 +456,25 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -419,22 +456,25 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim, output_lin_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()), static_cast<const void*>(input_lin_output_grads.data_ptr()),
//static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r,
CUDA_R_16F,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()), static_cast<void*>(input_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(input_grads.data_ptr()),
//CUBLAS_GEMM_ALGO10_TENSOR_OP)); rocblas_datatype_f16_r,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); embed_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// Input Linear Wgrad // Input Linear Wgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -442,20 +482,24 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -442,20 +482,24 @@ std::vector<torch::Tensor> bwd_cuda(
batches, batches,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()), static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F, rocblas_datatype_f16_r,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()), static_cast<void*>(input_weight_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, rocblas_datatype_f32_r,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_grads, input_grads,
...@@ -466,6 +510,6 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -466,6 +510,6 @@ std::vector<torch::Tensor> bwd_cuda(
}; };
} }
} // end namespace cublas_gemmex } // end namespace rocblas_gemmex
} // end namespace self_bias } // end namespace self_bias
} // end namespace multihead_attn } // end namespace multihead_attn
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
namespace multihead_attn { namespace multihead_attn {
namespace self { namespace self {
namespace cublas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -121,12 +121,12 @@ std::vector<torch::Tensor> bwd( ...@@ -121,12 +121,12 @@ std::vector<torch::Tensor> bwd(
); );
} }
} // end namespace cublas_gemmex } // end namespace rocblas_gemmex
} // end namespace self } // end namespace self
} // end namespace multihead_attn } // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self::cublas_gemmex::fwd, "Self Multihead Attention Forward."); m.def("forward", &multihead_attn::self::rocblas_gemmex::fwd, "Self Multihead Attention Forward.");
m.def("backward", &multihead_attn::self::cublas_gemmex::bwd, "Self Multihead Attention Backward."); m.def("backward", &multihead_attn::self::rocblas_gemmex::bwd, "Self Multihead Attention Backward.");
} }
#include <vector> #include <vector>
#include <iostream> #include <iostream>
// The lines below enable HIP float-to-half conversions, which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h> //#include <cuda_profiler_api.h>
#include "THC/THC.h" #include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
...@@ -21,7 +24,7 @@ extern THCState *state; ...@@ -21,7 +24,7 @@ extern THCState *state;
namespace multihead_attn { namespace multihead_attn {
namespace self { namespace self {
namespace cublas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -78,9 +81,8 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -78,9 +81,8 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd // Input Linear Fwd
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_dim, output_lin_dim,
...@@ -88,17 +90,22 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -88,17 +90,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(inputs.data_ptr()), static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
q_lin_results_ptr, q_lin_results_ptr,
CUDA_R_16F, rocblas_datatype_f16_r,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim, output_lin_dim,
CUDA_R_32F, rocblas_datatype_f32_r,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
...@@ -118,6 +125,9 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -118,6 +125,9 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(softmax_results_ptr), static_cast<half*>(softmax_results_ptr),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches); attn_batches);
// Padded Softmax // Padded Softmax
...@@ -179,10 +189,13 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -179,10 +189,13 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(matmul2_results.data_ptr()), static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches, head_dim*attn_batches,
head_dim, head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches); attn_batches);
// Output Linear // Output Linear
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -190,19 +203,22 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -190,19 +203,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(outputs.data_ptr()), static_cast<void*>(outputs.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(outputs.data_ptr()),
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); rocblas_datatype_f16_r,
embed_dim,
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); rocblas_datatype_f32_r,
algo,
solution_index,
flags));
return { return {
input_lin_results, input_lin_results,
...@@ -270,11 +286,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -270,11 +286,9 @@ std::vector<torch::Tensor> bwd_cuda(
char a_layout_t{'t'}; char a_layout_t{'t'};
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -282,20 +296,25 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -282,20 +296,25 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(output_grads.data_ptr()), static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()), static_cast<void*>(output_lin_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(output_lin_grads.data_ptr()),
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -303,17 +322,22 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -303,17 +322,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches, batches,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(output_grads.data_ptr()), static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()), static_cast<void*>(output_weight_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, rocblas_datatype_f32_r,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// MatMul2 Dgrad1 // MatMul2 Dgrad1
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
...@@ -333,6 +357,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -333,6 +357,9 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches); attn_batches);
// Matmul2 Dgrad2 // Matmul2 Dgrad2
...@@ -353,6 +380,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -353,6 +380,9 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr, v_lin_grads_ptr,
lead_dim, lead_dim,
batch_stride, batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches); attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability // Apply Dropout Mask and Scale by Dropout Probability
...@@ -392,6 +422,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -392,6 +422,9 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr, q_lin_grads_ptr,
lead_dim, lead_dim,
batch_stride, batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches); attn_batches);
// Matmul1 Dgrad2 // Matmul1 Dgrad2
...@@ -411,11 +444,14 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -411,11 +444,14 @@ std::vector<torch::Tensor> bwd_cuda(
beta, beta,
k_lin_grads_ptr, k_lin_grads_ptr,
lead_dim, lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride, batch_stride,
attn_batches); attn_batches);
// Input Linear Dgrad // Input Linear Dgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -423,20 +459,25 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -423,20 +459,25 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim, output_lin_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F, rocblas_datatype_f16_r,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()), static_cast<void*>(input_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(input_grads.data_ptr()),
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// Input Linear Wgrad // Input Linear Wgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -444,18 +485,22 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -444,18 +485,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches, batches,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()), static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F, rocblas_datatype_f16_r,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()), static_cast<void*>(input_weight_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
CUDA_R_32F, rocblas_datatype_f32_r,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); solution_index,
flags));
return { return {
input_grads, input_grads,
...@@ -464,6 +509,6 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -464,6 +509,6 @@ std::vector<torch::Tensor> bwd_cuda(
}; };
} }
} // end namespace cublas_gemmex } // end namespace rocblas_gemmex
} // end namespace self } // end namespace self
} // end namespace multihead_attn } // end namespace multihead_attn
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
namespace multihead_attn { namespace multihead_attn {
namespace self_norm_add { namespace self_norm_add {
namespace cublas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -167,7 +167,7 @@ std::vector<torch::Tensor> bwd( ...@@ -167,7 +167,7 @@ std::vector<torch::Tensor> bwd(
} // end namespace multihead_attn } // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_norm_add::cublas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward."); m.def("forward", &multihead_attn::self_norm_add::rocblas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def("backward", &multihead_attn::self_norm_add::cublas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward."); m.def("backward", &multihead_attn::self_norm_add::rocblas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward.");
} }
#include <vector> #include <vector>
#include <iostream> #include <iostream>
// The lines below enable HIP float-to-half conversions, which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h" #include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
...@@ -21,7 +25,7 @@ extern THCState *state; ...@@ -21,7 +25,7 @@ extern THCState *state;
namespace multihead_attn { namespace multihead_attn {
namespace self_norm_add { namespace self_norm_add {
namespace cublas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -88,7 +92,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -88,7 +92,7 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm // Layer Norm
HostApplyLayerNorm<at::Half,float>( HostApplyLayerNorm<at::Half,float>(
static_cast<at::Half*>(lyr_nrm_results.data_ptr()), static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
...@@ -102,7 +106,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -102,7 +106,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr())); static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Fwd // Input Linear Fwd
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_dim, output_lin_dim,
...@@ -110,18 +114,23 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -110,18 +114,23 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
//static_cast<const void*>(inputs.data_ptr()), //static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()), static_cast<const void*>(lyr_nrm_results.data_ptr()),
CUDA_R_16F, b_type,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
q_lin_results_ptr, q_lin_results_ptr,
CUDA_R_16F, c_type,
output_lin_dim,
q_lin_results_ptr,
d_type,
output_lin_dim, output_lin_dim,
CUDA_R_32F, compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
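Unlike the other ported files, this one routes the GEMM datatypes through a_type/b_type/c_type/d_type/compute_type variables rather than passing the rocblas_datatype_* enums inline. Their declarations fall outside the hunks shown; given the fp16 storage and fp32 accumulation used everywhere else in this commit, they are presumably initialised along these lines (a sketch, not the committed code):

```
// Presumed datatype/algorithm setup for this file (not visible in the hunks):
rocblas_datatype a_type       = rocblas_datatype_f16_r; // fp16 inputs
rocblas_datatype b_type       = rocblas_datatype_f16_r;
rocblas_datatype c_type       = rocblas_datatype_f16_r; // fp16 outputs
rocblas_datatype d_type       = rocblas_datatype_f16_r;
rocblas_datatype compute_type = rocblas_datatype_f32_r; // accumulate in fp32
rocblas_gemm_algo algo        = rocblas_gemm_algo_standard;
int32_t  solution_index       = 0;
uint32_t flags                = 0;
```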
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
...@@ -141,6 +150,9 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -141,6 +150,9 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(softmax_results_ptr), static_cast<half*>(softmax_results_ptr),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches); attn_batches);
// Padded Softmax // Padded Softmax
...@@ -202,11 +214,14 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -202,11 +214,14 @@ std::vector<torch::Tensor> fwd_cuda(
beta, beta,
static_cast<half*>(matmul2_results.data_ptr()), static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches, head_dim*attn_batches,
head_dim, head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches); attn_batches);
// Output Linear // Output Linear
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -214,18 +229,24 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -214,18 +229,24 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F, b_type,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()), static_cast<void*>(output_lin_results.data_ptr()),
CUDA_R_16F, c_type,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(output_lin_results.data_ptr()),
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); d_type,
embed_dim,
compute_type,
algo,
solution_index,
flags));
// End-of-block Dropout-Add // End-of-block Dropout-Add
if (is_training) { if (is_training) {
apex_dropout_add_cuda<at::Half,float,uint32_t>( apex_dropout_add_cuda<at::Half,float,uint32_t>(
...@@ -243,8 +264,6 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -243,8 +264,6 @@ std::vector<torch::Tensor> fwd_cuda(
total_tokens); total_tokens);
} }
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
lyr_nrm_results, lyr_nrm_results,
lyr_nrm_mean, lyr_nrm_mean,
...@@ -327,8 +346,6 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -327,8 +346,6 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Dropout Add Backward // Dropout Add Backward
apex_masked_scale_cuda<at::Half,float,uint32_t>( apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_grads.data_ptr()), static_cast<at::Half const*>(output_grads.data_ptr()),
...@@ -338,7 +355,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -338,7 +355,7 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob))); (1.0 / (1.0 - dropout_prob)));
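The 1/(1 − dropout_prob) factor above is the standard inverted-dropout scaling: the forward pass rescales kept activations by the same factor so the expected value is unchanged, and the backward pass therefore applies the identical mask and scale to the incoming gradient. A short derivation, assuming standard inverted dropout with keep probability 1 − p:

```
y_i = \frac{m_i\, x_i}{1-p}, \quad m_i \sim \mathrm{Bernoulli}(1-p)
\;\;\Rightarrow\;\;
\mathbb{E}[y_i] = x_i, \qquad
\frac{\partial \mathcal{L}}{\partial x_i}
  = \frac{m_i}{1-p}\,\frac{\partial \mathcal{L}}{\partial y_i}.
```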
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -346,20 +363,25 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -346,20 +363,25 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()), static_cast<const void*>(dropout_add_grads.data_ptr()),
CUDA_R_16F, b_type,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()), static_cast<void*>(output_lin_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
d_type,
embed_dim, embed_dim,
CUDA_R_32F, compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -367,18 +389,23 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -367,18 +389,23 @@ std::vector<torch::Tensor> bwd_cuda(
batches, batches,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()), static_cast<const void*>(dropout_add_grads.data_ptr()),
CUDA_R_16F, b_type,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()), static_cast<void*>(output_weight_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(output_weight_grads.data_ptr()),
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); d_type,
embed_dim,
compute_type,
algo,
solution_index,
flags));
// MatMul2 Dgrad1 // MatMul2 Dgrad1
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
a_layout_t, a_layout_t,
...@@ -397,6 +424,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -397,6 +424,9 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches); attn_batches);
// Matmul2 Dgrad2 // Matmul2 Dgrad2
...@@ -417,6 +447,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -417,6 +447,9 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr, v_lin_grads_ptr,
lead_dim, lead_dim,
batch_stride, batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches); attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability // Apply Dropout Mask and Scale by Dropout Probability
...@@ -455,6 +488,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -455,6 +488,9 @@ std::vector<torch::Tensor> bwd_cuda(
beta, beta,
q_lin_grads_ptr, q_lin_grads_ptr,
lead_dim, lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride, batch_stride,
attn_batches); attn_batches);
...@@ -475,11 +511,14 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -475,11 +511,14 @@ std::vector<torch::Tensor> bwd_cuda(
beta, beta,
k_lin_grads_ptr, k_lin_grads_ptr,
lead_dim, lead_dim,
batch_stride, batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches); attn_batches);
// Input Linear Dgrad // Input Linear Dgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -487,22 +526,26 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -487,22 +526,26 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim, output_lin_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F, b_type,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
//static_cast<void*>(input_grads.data_ptr()), //static_cast<void*>(input_grads.data_ptr()),
static_cast<void*>(input_lin_grads.data_ptr()), static_cast<void*>(input_lin_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(input_lin_grads.data_ptr()),
//CUBLAS_GEMM_ALGO10_TENSOR_OP)); d_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); embed_dim,
compute_type,
algo,
solution_index,
flags));
// Input Linear Wgrad // Input Linear Wgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -511,17 +554,22 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -511,17 +554,22 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
//static_cast<const void*>(inputs.data_ptr()), //static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()), static_cast<const void*>(lyr_nrm_results.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F, b_type,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()), static_cast<void*>(input_weight_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
d_type,
embed_dim, embed_dim,
CUDA_R_32F, compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// Fused Layer Norm Bwd with Residual Add // Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half,float>( HostLayerNormGradient<half,float>(
...@@ -540,7 +588,6 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -540,7 +588,6 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(lyr_nrm_beta_grads.data_ptr()) static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
); );
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_grads, input_grads,
...@@ -551,6 +598,6 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -551,6 +598,6 @@ std::vector<torch::Tensor> bwd_cuda(
}; };
} }
} // end namespace cublas_gemmex } // end namespace rocblas_gemmex
} // end namespace self_norm_add } // end namespace self_norm_add
} // end namespace multihead_attn } // end namespace multihead_attn