Unverified Commit 9f899769 authored by Hubert Lu's avatar Hubert Lu Committed by GitHub
Browse files

Merge pull request #56 from ROCmSoftwarePlatform/dev/hubertlu/multihead_attn

Enable multihead atten
parents 325246e4 62f06964
......@@ -4,3 +4,6 @@ build
docs/build
*~
__pycache__
*.hip
*_hip.*
*hip*
......@@ -183,4 +183,4 @@ void ln_fwd_cuda(
assert(false && "Not implemented");
}
}
\ No newline at end of file
}
......@@ -5,7 +5,7 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
//#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
......
......@@ -3,7 +3,7 @@
namespace multihead_attn {
namespace encdec {
namespace cublas_gemmex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
......@@ -146,11 +146,11 @@ std::vector<torch::Tensor> bwd(
);
}
} // end namespace cublas_gemmex
} // end namespace rocblas_gemm_ex
} // end namespace encdec
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::encdec::cublas_gemmex::fwd, "Encdec Multihead Attention Forward.");
m.def("backward", &multihead_attn::encdec::cublas_gemmex::bwd, "Encdec Multihead Attention Backward.");
m.def("forward", &multihead_attn::encdec::rocblas_gemmex::fwd, "Encdec Multihead Attention Forward.");
m.def("backward", &multihead_attn::encdec::rocblas_gemmex::bwd, "Encdec Multihead Attention Backward.");
}
#include <vector>
#include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
......@@ -22,7 +25,7 @@ extern THCState *state;
namespace multihead_attn {
namespace encdec {
namespace cublas_gemmex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
......@@ -86,9 +89,9 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'};
char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Q Fwd
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_q_dim,
......@@ -96,20 +99,25 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs_q.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
CUDA_R_16F,
rocblas_datatype_f16_r,
output_lin_q_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_q_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// Input Linear KV Fwd
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_kv_dim,
......@@ -117,17 +125,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
k_lin_results_ptr,
CUDA_R_16F,
rocblas_datatype_f16_r,
output_lin_kv_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
k_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_kv_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state,
......@@ -146,6 +159,9 @@ std::vector<torch::Tensor> fwd_cuda(
beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
......@@ -208,10 +224,13 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches);
// Output Linear
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
......@@ -219,20 +238,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(outputs.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
//CUBLAS_GEMM_ALGO1_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
return {
input_lin_q_results,
......@@ -312,10 +333,8 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'};
char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -323,20 +342,25 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -344,17 +368,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// MatMul2 Dgrad1
gemm_switch_fp32accum( state,
......@@ -374,6 +403,9 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// Matmul2 Dgrad2
......@@ -394,6 +426,9 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
v_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability
......@@ -433,6 +468,9 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
attn_batches);
// Matmul1 Dgrad2
......@@ -453,10 +491,13 @@ std::vector<torch::Tensor> bwd_cuda(
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches);
// Input Linear Q Dgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -464,21 +505,25 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_q_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F,
rocblas_datatype_f16_r,
output_lin_q_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_q_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
//CUBLAS_GEMM_ALGO10_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
static_cast<void*>(input_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// Input Linear Q Wgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -486,20 +531,25 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_q.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F,
rocblas_datatype_f16_r,
output_lin_q_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_q_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// Input Linear KV Dgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -507,21 +557,25 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_kv_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
CUDA_R_16F,
rocblas_datatype_f16_r,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_kv_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
//CUBLAS_GEMM_ALGO10_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// Input Linear KV Wgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -529,18 +583,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches_kv,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_kv.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
CUDA_R_16F,
rocblas_datatype_f16_r,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
return {
input_q_grads,
......@@ -551,6 +609,6 @@ std::vector<torch::Tensor> bwd_cuda(
};
}
} // end namespace cublas_gemmex
} // end namespace rocblas_gemmex
} // end namespace encdec
} // end namespace multihead_attn
......@@ -3,7 +3,7 @@
namespace multihead_attn {
namespace encdec_norm_add {
namespace cublas_gemmex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
......@@ -192,7 +192,7 @@ std::vector<torch::Tensor> bwd(
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::encdec_norm_add::cublas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def("backward", &multihead_attn::encdec_norm_add::cublas_gemmex::bwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.");
m.def("forward", &multihead_attn::encdec_norm_add::rocblas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def("backward", &multihead_attn::encdec_norm_add::rocblas_gemmex::bwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.");
}
#include <vector>
#include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
......@@ -21,7 +25,7 @@ extern THCState *state;
namespace multihead_attn {
namespace encdec_norm_add {
namespace cublas_gemmex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
......@@ -95,7 +99,6 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'};
char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm
HostApplyLayerNorm<at::Half,float>(
static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
......@@ -109,7 +112,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Q Fwd
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_q_dim,
......@@ -117,21 +120,26 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
CUDA_R_16F,
a_type,
embed_dim,
//static_cast<const void*>(inputs_q.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
CUDA_R_16F,
b_type,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
CUDA_R_16F,
c_type,
output_lin_q_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
q_lin_results_ptr,
d_type,
output_lin_q_dim,
compute_type,
algo,
solution_index,
flags));
// Input Linear KV Fwd
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_kv_dim,
......@@ -139,18 +147,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
CUDA_R_16F,
a_type,
embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()),
CUDA_R_16F,
b_type,
embed_dim,
static_cast<const void*>(&beta),
k_lin_results_ptr,
CUDA_R_16F,
c_type,
output_lin_kv_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
k_lin_results_ptr,
d_type,
output_lin_kv_dim,
compute_type,
algo,
solution_index,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state,
a_layout_t,
......@@ -168,7 +180,10 @@ std::vector<torch::Tensor> fwd_cuda(
beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// Padded Softmax
......@@ -230,11 +245,14 @@ std::vector<torch::Tensor> fwd_cuda(
beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches);
// Output Linear
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
......@@ -242,19 +260,23 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F,
a_type,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F,
b_type,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()),
CUDA_R_16F,
c_type,
embed_dim,
CUDA_R_32F,
//CUBLAS_GEMM_ALGO1_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
static_cast<void*>(output_lin_results.data_ptr()),
d_type,
embed_dim,
compute_type,
algo,
solution_index,
flags));
// End-of-block Dropout-Add
if (is_training) {
apex_dropout_add_cuda<at::Half,float,uint32_t>(
......@@ -272,8 +294,6 @@ std::vector<torch::Tensor> fwd_cuda(
total_tokens_q);
}
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
lyr_nrm_results,
lyr_nrm_mean,
......@@ -366,9 +386,7 @@ std::vector<torch::Tensor> bwd_cuda(
char a_layout_t{'t'};
char b_layout_n{'n'};
char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Dropout Add Backward
apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_grads.data_ptr()),
......@@ -378,7 +396,7 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -386,20 +404,25 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F,
a_type,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
CUDA_R_16F,
b_type,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
CUDA_R_16F,
c_type,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
static_cast<void*>(output_lin_grads.data_ptr()),
d_type,
embed_dim,
compute_type,
algo,
solution_index,
flags));
// Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -407,17 +430,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F,
a_type,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
CUDA_R_16F,
b_type,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
CUDA_R_16F,
c_type,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
static_cast<void*>(output_weight_grads.data_ptr()),
d_type,
embed_dim,
compute_type,
algo,
solution_index,
flags));
// MatMul2 Dgrad1
gemm_switch_fp32accum( state,
......@@ -437,6 +465,9 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// Matmul2 Dgrad2
......@@ -457,6 +488,9 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
v_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability
......@@ -496,6 +530,9 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
attn_batches);
// Matmul1 Dgrad2
......@@ -515,11 +552,14 @@ std::vector<torch::Tensor> bwd_cuda(
beta,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
batch_stride_kv,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches);
// Input Linear Q Dgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -527,22 +567,26 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_q_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
CUDA_R_16F,
a_type,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F,
b_type,
output_lin_q_dim,
static_cast<const void*>(&beta),
//static_cast<void*>(input_q_grads.data_ptr()),
static_cast<void*>(input_lin_q_grads.data_ptr()),
CUDA_R_16F,
c_type,
embed_dim,
CUDA_R_32F,
//CUBLAS_GEMM_ALGO10_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
static_cast<void*>(input_lin_q_grads.data_ptr()),
d_type,
embed_dim,
compute_type,
algo,
solution_index,
flags));
// Input Linear Q Wgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -550,20 +594,25 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_q.data_ptr()),
CUDA_R_16F,
a_type,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F,
b_type,
output_lin_q_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_q_grads.data_ptr()),
CUDA_R_16F,
c_type,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
static_cast<void*>(input_weight_q_grads.data_ptr()),
d_type,
embed_dim,
compute_type,
algo,
solution_index,
flags));
// Input Linear KV Dgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -571,21 +620,25 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_kv_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
CUDA_R_16F,
a_type,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
CUDA_R_16F,
b_type,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_kv_grads.data_ptr()),
CUDA_R_16F,
c_type,
embed_dim,
CUDA_R_32F,
//CUBLAS_GEMM_ALGO10_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
static_cast<void*>(input_kv_grads.data_ptr()),
d_type,
embed_dim,
compute_type,
algo,
solution_index,
flags));
// Input Linear KV Wgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -593,17 +646,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches_kv,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_kv.data_ptr()),
CUDA_R_16F,
a_type,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
CUDA_R_16F,
b_type,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()),
CUDA_R_16F,
c_type,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
static_cast<void*>(input_weight_kv_grads.data_ptr()),
d_type,
embed_dim,
compute_type,
algo,
solution_index,
flags));
// Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half,float>(
......@@ -622,7 +680,6 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
);
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_q_grads,
......@@ -635,6 +692,6 @@ std::vector<torch::Tensor> bwd_cuda(
};
}
} // end namespace cublas_gemmex
} // end namespace rocblas_gemmex
} // end namespace encdec_norm_add
} // end namespace multihead_attn
......@@ -4,6 +4,7 @@
#include <cuda.h>
#include <cuda_runtime.h>
template<typename U> __device__
void cuWelfordOnlineSum(
const U curr,
......@@ -84,9 +85,9 @@ void cuWelfordMuSigma2(
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31;
U muB = WARP_SHFL(mu, srcLaneB);
U countB = WARP_SHFL(count, srcLaneB);
U sigma2B = WARP_SHFL(sigma2, srcLaneB);
U muB = WARP_SHFL(mu, srcLaneB, 32);
U countB = WARP_SHFL(count, srcLaneB, 32);
U sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
}
// threadIdx.x == 0 has correct values for each warp
......@@ -122,8 +123,8 @@ void cuWelfordMuSigma2(
sigma2 = ubuf[1]/U(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2/U(n2), 0);
mu = WARP_SHFL(mu, 0, 32);
sigma2 = WARP_SHFL(sigma2/U(n2), 0, 32);
}
}
}
......@@ -180,9 +181,9 @@ void cuWelfordMuSigma2(
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31;
float muB = WARP_SHFL(mu, srcLaneB);
float countB = WARP_SHFL(count, srcLaneB);
float sigma2B = WARP_SHFL(sigma2, srcLaneB);
float muB = WARP_SHFL(mu, srcLaneB, 32);
float countB = WARP_SHFL(count, srcLaneB, 32);
float sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
}
// threadIdx.x == 0 has correct values for each warp
......@@ -218,8 +219,8 @@ void cuWelfordMuSigma2(
sigma2 = ubuf[1]/float(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2/float(n2), 0);
mu = WARP_SHFL(mu, 0, 32);
sigma2 = WARP_SHFL(sigma2/float(n2), 0, 32);
}
}
}
......@@ -227,9 +228,19 @@ void cuWelfordMuSigma2(
template<typename U> U rsqrt(U v) {
return U(1) / sqrt(v);
}
//template<> float rsqrt(float v) {
// return rsqrtf(v);
//}
#if defined __HIP_PLATFORM_HCC__
__device__ float rsqrt(float v) {
return rsqrtf(v);
}
#else
template<> float rsqrt(float v) {
return rsqrtf(v);
}
#endif
template<> double rsqrt(double v) {
return rsqrt(v);
}
......@@ -290,7 +301,7 @@ void cuApplyLayerNorm(
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
U mu,sigma2;
......@@ -529,7 +540,7 @@ void cuComputeGradInput(
const T* gamma,
T* grad_input)
{
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
......@@ -574,8 +585,8 @@ void cuComputeGradInput(
}
// intra-warp reductions
for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32);
}
// inter-warp reductions
if (blockDim.y > 1) {
......
#include <vector>
#include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
......
......@@ -4,7 +4,7 @@
namespace multihead_attn {
namespace self_bias_additive_mask {
namespace cublas_gemmex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
......@@ -132,12 +132,12 @@ std::vector<torch::Tensor> bwd(
);
}
} // end namespace cublas_gemmex
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
m.def("forward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
#include <vector>
#include <math.h>
#include <iostream>
#include <ATen/ATen.h>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
//#include <cuda_profiler_api.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
......@@ -21,7 +24,7 @@ extern THCState *state;
namespace multihead_attn {
namespace self_bias_additive_mask {
namespace cublas_gemmex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
......@@ -48,8 +51,8 @@ std::vector<torch::Tensor> fwd_cuda(
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// There is no reason to use more than one stream as every kernel is
......@@ -82,10 +85,9 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'};
char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
input_lin_results.copy_(input_biases);
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
......@@ -93,18 +95,23 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
q_lin_results_ptr,
CUDA_R_16F,
rocblas_datatype_f16_r,
output_lin_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state,
a_layout_t,
......@@ -123,7 +130,11 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// Padded Softmax
bool softmax_success = false;
if (is_training) {
......@@ -168,12 +179,15 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches);
outputs.copy_(output_biases);
// Output Linear
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
......@@ -181,20 +195,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
//CUBLAS_GEMM_ALGO1_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
return {
input_lin_results,
......@@ -264,10 +280,8 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'};
char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -275,19 +289,25 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -295,17 +315,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1
......@@ -326,8 +351,11 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// Matmul2 Dgrad2
gemm_switch_fp32accum( state,
a_layout_n,
......@@ -346,6 +374,9 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability
......@@ -362,7 +393,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches*q_seq_len/sequences,
attn_batches*q_seq_len,
stream);
// Matmul1 Dgrad1
gemm_switch_fp32accum( state,
a_layout_n,
......@@ -381,8 +412,11 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
// Matmul1 Dgrad2
gemm_switch_fp32accum( state,
a_layout_n,
......@@ -401,9 +435,13 @@ std::vector<torch::Tensor> bwd_cuda(
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
// Input Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -411,22 +449,25 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
//static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
//CUBLAS_GEMM_ALGO10_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// Input Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -434,20 +475,24 @@ std::vector<torch::Tensor> bwd_cuda(
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F,
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_grads,
......@@ -458,6 +503,6 @@ std::vector<torch::Tensor> bwd_cuda(
};
}
} // end namespace cublas_gemmex
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
......@@ -3,7 +3,7 @@
namespace multihead_attn {
namespace self_bias {
namespace cublas_gemmex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
......@@ -128,12 +128,12 @@ std::vector<torch::Tensor> bwd(
);
}
} // end namespace cublas_gemmex
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
m.def("forward", &multihead_attn::self_bias::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
#include <vector>
#include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
//#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
......@@ -21,7 +24,7 @@ extern THCState *state;
namespace multihead_attn {
namespace self_bias {
namespace cublas_gemmex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
......@@ -80,11 +83,10 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_t{'t'};
char a_layout_n{'n'};
char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
input_lin_results.copy_(input_biases);
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
......@@ -92,17 +94,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
q_lin_results_ptr,
CUDA_R_16F,
rocblas_datatype_f16_r,
output_lin_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state,
......@@ -122,7 +129,11 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
......@@ -180,12 +191,15 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches);
outputs.copy_(output_biases);
// Output Linear
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
......@@ -193,20 +207,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
//CUBLAS_GEMM_ALGO1_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
return {
input_lin_results,
......@@ -275,10 +291,8 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'};
char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -286,19 +300,25 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -306,17 +326,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1
......@@ -337,6 +362,9 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// Matmul2 Dgrad2
......@@ -357,6 +385,9 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability
......@@ -385,7 +416,10 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
......@@ -408,10 +442,13 @@ std::vector<torch::Tensor> bwd_cuda(
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
// Input Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -419,22 +456,25 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
//static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
//CUBLAS_GEMM_ALGO10_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// Input Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -442,20 +482,24 @@ std::vector<torch::Tensor> bwd_cuda(
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F,
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_grads,
......@@ -466,6 +510,6 @@ std::vector<torch::Tensor> bwd_cuda(
};
}
} // end namespace cublas_gemmex
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
......@@ -3,7 +3,7 @@
namespace multihead_attn {
namespace self {
namespace cublas_gemmex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
......@@ -121,12 +121,12 @@ std::vector<torch::Tensor> bwd(
);
}
} // end namespace cublas_gemmex
} // end namespace rocblas_gemm_ex
} // end namespace self
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self::cublas_gemmex::fwd, "Self Multihead Attention Forward.");
m.def("backward", &multihead_attn::self::cublas_gemmex::bwd, "Self Multihead Attention Backward.");
m.def("forward", &multihead_attn::self::rocblas_gemmex::fwd, "Self Multihead Attention Forward.");
m.def("backward", &multihead_attn::self::rocblas_gemmex::bwd, "Self Multihead Attention Backward.");
}
#include <vector>
#include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
//#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
......@@ -21,7 +24,7 @@ extern THCState *state;
namespace multihead_attn {
namespace self {
namespace cublas_gemmex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
......@@ -78,9 +81,8 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'};
char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
......@@ -88,17 +90,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
CUDA_R_16F,
rocblas_datatype_f16_r,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state,
......@@ -118,6 +125,9 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// Padded Softmax
......@@ -179,10 +189,13 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches);
// Output Linear
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
......@@ -190,19 +203,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(outputs.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
return {
input_lin_results,
......@@ -270,11 +286,9 @@ std::vector<torch::Tensor> bwd_cuda(
char a_layout_t{'t'};
char b_layout_n{'n'};
char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -282,20 +296,25 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -303,17 +322,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// MatMul2 Dgrad1
gemm_switch_fp32accum( state,
......@@ -333,6 +357,9 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// Matmul2 Dgrad2
......@@ -353,6 +380,9 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability
......@@ -392,6 +422,9 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
// Matmul1 Dgrad2
......@@ -411,11 +444,14 @@ std::vector<torch::Tensor> bwd_cuda(
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
// Input Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -423,20 +459,25 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F,
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// Input Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -444,18 +485,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F,
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
return {
input_grads,
......@@ -464,6 +509,6 @@ std::vector<torch::Tensor> bwd_cuda(
};
}
} // end namespace cublas_gemmex
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
......@@ -3,7 +3,7 @@
namespace multihead_attn {
namespace self_norm_add {
namespace cublas_gemmex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
......@@ -167,7 +167,7 @@ std::vector<torch::Tensor> bwd(
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_norm_add::cublas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def("backward", &multihead_attn::self_norm_add::cublas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward.");
m.def("forward", &multihead_attn::self_norm_add::rocblas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def("backward", &multihead_attn::self_norm_add::rocblas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward.");
}
#include <vector>
#include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
......@@ -21,7 +25,7 @@ extern THCState *state;
namespace multihead_attn {
namespace self_norm_add {
namespace cublas_gemmex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
......@@ -88,7 +92,7 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'};
char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm
HostApplyLayerNorm<at::Half,float>(
static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
......@@ -102,7 +106,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Fwd
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
......@@ -110,18 +114,23 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F,
a_type,
embed_dim,
//static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
CUDA_R_16F,
b_type,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
CUDA_R_16F,
c_type,
output_lin_dim,
q_lin_results_ptr,
d_type,
output_lin_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
compute_type,
algo,
solution_index,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state,
......@@ -141,6 +150,9 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// Padded Softmax
......@@ -202,11 +214,14 @@ std::vector<torch::Tensor> fwd_cuda(
beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches);
// Output Linear
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
......@@ -214,18 +229,24 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F,
a_type,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F,
b_type,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()),
CUDA_R_16F,
c_type,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
static_cast<void*>(output_lin_results.data_ptr()),
d_type,
embed_dim,
compute_type,
algo,
solution_index,
flags));
// End-of-block Dropout-Add
if (is_training) {
apex_dropout_add_cuda<at::Half,float,uint32_t>(
......@@ -243,8 +264,6 @@ std::vector<torch::Tensor> fwd_cuda(
total_tokens);
}
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
lyr_nrm_results,
lyr_nrm_mean,
......@@ -327,8 +346,6 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'};
char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Dropout Add Backward
apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_grads.data_ptr()),
......@@ -338,7 +355,7 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -346,20 +363,25 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F,
a_type,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
CUDA_R_16F,
b_type,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
CUDA_R_16F,
c_type,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
d_type,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
compute_type,
algo,
solution_index,
flags));
// Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -367,18 +389,23 @@ std::vector<torch::Tensor> bwd_cuda(
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F,
a_type,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
CUDA_R_16F,
b_type,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
CUDA_R_16F,
c_type,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
static_cast<void*>(output_weight_grads.data_ptr()),
d_type,
embed_dim,
compute_type,
algo,
solution_index,
flags));
// MatMul2 Dgrad1
gemm_switch_fp32accum( state,
a_layout_t,
......@@ -397,6 +424,9 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// Matmul2 Dgrad2
......@@ -417,6 +447,9 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability
......@@ -455,6 +488,9 @@ std::vector<torch::Tensor> bwd_cuda(
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
......@@ -475,11 +511,14 @@ std::vector<torch::Tensor> bwd_cuda(
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
// Input Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -487,22 +526,26 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F,
a_type,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F,
b_type,
output_lin_dim,
static_cast<const void*>(&beta),
//static_cast<void*>(input_grads.data_ptr()),
static_cast<void*>(input_lin_grads.data_ptr()),
CUDA_R_16F,
c_type,
embed_dim,
CUDA_R_32F,
//CUBLAS_GEMM_ALGO10_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
static_cast<void*>(input_lin_grads.data_ptr()),
d_type,
embed_dim,
compute_type,
algo,
solution_index,
flags));
// Input Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -511,17 +554,22 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<const void*>(&alpha),
//static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
CUDA_R_16F,
a_type,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F,
b_type,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
CUDA_R_16F,
c_type,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
d_type,
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
compute_type,
algo,
solution_index,
flags));
// Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half,float>(
......@@ -540,7 +588,6 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
);
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_grads,
......@@ -551,6 +598,6 @@ std::vector<torch::Tensor> bwd_cuda(
};
}
} // end namespace cublas_gemmex
} // end namespace rocblas_gemmex
} // end namespace self_norm_add
} // end namespace multihead_attn
......@@ -11,7 +11,14 @@
#include <cuda_fp16.h>
#include <cmath>
#ifdef __HIP_PLATFORM_HCC__
#define APEX_WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor(value, offset, width)
#else
#define APEX_WARP_SHFL_XOR __shfl_xor_sync
#endif
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
......@@ -127,7 +134,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batc
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
......@@ -152,7 +159,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batc
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
......@@ -351,7 +358,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst,
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
......@@ -375,7 +382,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst,
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
auto seeds = at::cuda::philox::unpack(philox_args);
......@@ -505,7 +512,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
......@@ -529,7 +536,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
curandStatePhilox4_32_10_t state;
......@@ -765,7 +772,7 @@ __global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
......@@ -790,7 +797,7 @@ __global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
......@@ -1020,7 +1027,7 @@ __global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, c
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
......@@ -1045,7 +1052,7 @@ __global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, c
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
......@@ -1243,7 +1250,7 @@ __global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *s
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
......@@ -1268,7 +1275,7 @@ __global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *s
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
......@@ -1385,7 +1392,7 @@ bool dispatch_time_masked_softmax(output_t *dst, const input_t *src, const uint8
return false;
}
int log2_ceil_native(int value) {
static int log2_ceil_native(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
......@@ -1394,7 +1401,7 @@ int log2_ceil_native(int value) {
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
#if CUDA_VERSION >= 9000 && !defined(__HIP_PLATFORM_HCC__)
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
......@@ -1835,7 +1842,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
......@@ -1860,7 +1867,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
......@@ -2305,7 +2312,7 @@ __global__ void softmax_warp_backward(__half *gradInput, const __half *grad, con
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
......@@ -2516,7 +2523,7 @@ __global__ void masked_softmax_warp_backward(__half *gradInput, const __half *gr
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment