Commit ba0e5fa5 authored by hubertlu-tw

Hipify self_multihead_attn_bias_additive_mask.

parent c3ec9351
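
This commit ports the self-attention (bias + additive mask) extension from cuBLAS to rocBLAS: each cublasGemmEx call becomes rocblas_gemm_ex, and the cublasSetMathMode calls are commented out (rocBLAS does not expose an equivalent math-mode switch). The diff spans both the pybind11 binding source and the CUDA/HIP kernel source. For orientation, below is a minimal sketch of how one cublasGemmEx call maps onto the rocblas_gemm_ex argument list; it is illustrative only: the handle, sizes, and leading dimensions are placeholders, and the trailing algo/solution_index/flags values are assumed defaults, not copied from the hunks shown here.

// Sketch: how a cublasGemmEx call is rewritten as rocblas_gemm_ex in this port.
// All pointers and dimensions are placeholders; only the argument layout matters.
#include <rocblas.h>

rocblas_status fp16_gemm_fp32_accum(rocblas_handle handle,
                                    const void* A, const void* B, void* C,
                                    int m, int n, int k) {
  const float alpha = 1.0f;
  const float beta  = 0.0f;
  // Assumed defaults for the three trailing arguments referenced throughout the diff.
  rocblas_gemm_algo algo  = rocblas_gemm_algo_standard;
  int32_t  solution_index = 0;
  uint32_t flags          = 0;
  return rocblas_gemm_ex(handle,
                         rocblas_operation_transpose,   // was CUBLAS_OP_T
                         rocblas_operation_none,        // was CUBLAS_OP_N
                         m, n, k,
                         &alpha,
                         A, rocblas_datatype_f16_r, k,  // a_type (was CUDA_R_16F), lda
                         B, rocblas_datatype_f16_r, k,  // b_type (was CUDA_R_16F), ldb
                         &beta,
                         C, rocblas_datatype_f16_r, m,  // c_type (was CUDA_R_16F), ldc
                         C, rocblas_datatype_f16_r, m,  // d_type/ldd: extra output matrix vs. cublasGemmEx (D aliases C)
                         rocblas_datatype_f32_r,        // compute_type (was CUDA_R_32F)
                         algo, solution_index, flags);
}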
@@ -4,7 +4,7 @@
namespace multihead_attn {
namespace self_bias_additive_mask {
namespace cublas_gemmex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
@@ -132,12 +132,12 @@ std::vector<torch::Tensor> bwd(
);
}
} // end namespace cublas_gemmex
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
m.def("forward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
#include <vector>
#include <math.h>
#include <iostream>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
@@ -21,7 +21,7 @@ extern THCState *state;
namespace multihead_attn {
namespace self_bias_additive_mask {
namespace cublas_gemmex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
@@ -48,8 +48,8 @@ std::vector<torch::Tensor> fwd_cuda(
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// There is no reason to use more than one stream as every kernel is
@@ -81,11 +81,12 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_t{'t'};
char a_layout_n{'n'};
char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// TODO: CUBLAS_TENSOR_OP_MATH (https://github.com/ROCmSoftwarePlatform/apex/commit/1fd257e2cd777f1ef7df37590f6dc6b2a73cc518) (ok)
// TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// TODO: cublasGemmEx --> rocblas_gemm_ex (OK)
// Input Linear Fwd
input_lin_results.copy_(input_biases);
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
@@ -93,18 +94,42 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r, // a_type
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r, // b_type
embed_dim,
static_cast<const void*>(&beta_one),
q_lin_results_ptr,
CUDA_R_16F,
rocblas_datatype_f16_r, // c_type
output_lin_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
q_lin_results_ptr,
rocblas_datatype_f16_r, // d_type
output_lin_dim,
rocblas_datatype_f32_r, // compute_type
algo,
solution_index,
flags));
// TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
// CUBLAS_OP_T,
// CUBLAS_OP_N,
// output_lin_dim,
// batches,
// embed_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(input_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(inputs.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta_one),
// q_lin_results_ptr,
// CUDA_R_16F,
// output_lin_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// TODO: no matching function for call to "gemm_switch_fp32accum" (OK)
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state,
a_layout_t,
@@ -123,7 +148,31 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_t,
// b_layout_n,
// k_seq_len,
// q_seq_len,
// head_dim,
// scale,
// static_cast<const half*>(k_lin_results_ptr),
// lead_dim,
// batch_stride,
// static_cast<const half*>(q_lin_results_ptr),
// lead_dim,
// batch_stride,
// beta_zero,
// static_cast<half*>(bmm1_results_ptr),
// k_seq_len,
// k_seq_len*q_seq_len,
// attn_batches);
// Padded Softmax
bool softmax_success = false;
if (is_training) {
@@ -150,6 +199,7 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches*q_seq_len/sequences);
}
// TODO: no matching function for call to "gemm_switch_fp32accum" (OK)
// Matmul2
gemm_switch_fp32accum( state,
a_layout_n,
@@ -168,12 +218,34 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_n,
// head_dim,
// q_seq_len,
// k_seq_len,
// alpha,
// static_cast<const half*>(v_lin_results_ptr),
// lead_dim,
// batch_stride,
// static_cast<const half*>(dropout_results.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// beta_zero,
// static_cast<half*>(matmul2_results.data_ptr()),
// head_dim*attn_batches,
// head_dim,
// attn_batches);
outputs.copy_(output_biases);
// TODO: cublasGemmEx --> rocblas_gemm_ex (OK)
// Output Linear
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
@@ -181,20 +253,44 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r, // a_type
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r, // b_type
embed_dim,
static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r, // c_type
embed_dim,
CUDA_R_32F,
//CUBLAS_GEMM_ALGO1_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r, // d_type
embed_dim,
rocblas_datatype_f32_r, // compute_type
algo,
solution_index,
flags));
// TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
// CUBLAS_OP_T,
// CUBLAS_OP_N,
// embed_dim,
// batches,
// embed_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(output_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(matmul2_results.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta_one),
// static_cast<void*>(outputs.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// //CUBLAS_GEMM_ALGO1_TENSOR_OP));
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// TODO: CUBLAS_DEFAULT_MATH (https://github.com/ROCmSoftwarePlatform/apex/commit/1fd257e2cd777f1ef7df37590f6dc6b2a73cc518) (ok)
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_lin_results,
@@ -263,11 +359,12 @@ std::vector<torch::Tensor> bwd_cuda(
char a_layout_t{'t'};
char b_layout_n{'n'};
char b_layout_t{'t'};
// TODO: CUBLAS_TENSOR_OP_MATH (https://github.com/ROCmSoftwarePlatform/apex/commit/1fd257e2cd777f1ef7df37590f6dc6b2a73cc518) (ok)
// TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// TODO: cublasGemmEx --> rocblas_gemm_ex (OK)
// Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
@@ -275,39 +372,89 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r, // a_type
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r, // b_type
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r, // c_type
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r, // d_type
embed_dim,
rocblas_datatype_f32_r, // compute_type
algo,
solution_index,
flags));
// TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_N,
// embed_dim,
// batches,
// embed_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(output_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(output_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(output_lin_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // TODO: CUBLAS_GEMM_DEFAULT_TENSOR_OP
// TODO: cublasGemmEx --> rocblas_gemm_ex (OK)
// Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches,
embed_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r, // a_type
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r, // b_type
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r, // c_type
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r, // d_type
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
rocblas_datatype_f32_r, // compute_type
algo,
solution_index,
flags));
// TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_T,
// embed_dim,
// embed_dim,
// batches,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(matmul2_results.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(output_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(output_weight_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// TODO: no matching function for call to "gemm_switch_fp32accum" (OK)
// MatMul2 Dgrad1
gemm_switch_fp32accum( state,
a_layout_t,
@@ -326,8 +473,30 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_t,
// b_layout_n,
// k_seq_len,
// q_seq_len,
// head_dim,
// alpha,
// static_cast<const half*>(v_lin_results_ptr),
// lead_dim,
// batch_stride,
// static_cast<const half*>(output_lin_grads.data_ptr()),
// head_dim*attn_batches,
// head_dim,
// beta,
// static_cast<half*>(matmul2_grads.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// attn_batches);
// TODO: no matching function for call to "gemm_switch_fp32accum" (OK)
// Matmul2 Dgrad2
gemm_switch_fp32accum( state,
a_layout_n,
@@ -345,8 +514,29 @@ std::vector<torch::Tensor> bwd_cuda(
beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_t,
// head_dim,
// k_seq_len,
// q_seq_len,
// alpha,
// static_cast<const half*>(output_lin_grads.data_ptr()),
// head_dim*attn_batches,
// head_dim,
// static_cast<const half*>(dropout_results.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// beta,
// v_lin_grads_ptr,
// lead_dim,
// batch_stride,
// attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
@@ -362,7 +552,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches*q_seq_len/sequences,
attn_batches*q_seq_len,
stream);
// TODO: no matching function for call to "gemm_switch_fp32accum" (OK)
// Matmul1 Dgrad1
gemm_switch_fp32accum( state,
a_layout_n,
@@ -381,8 +571,30 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_n,
// head_dim,
// q_seq_len,
// k_seq_len,
// scale,
// k_lin_results_ptr,
// lead_dim,
// batch_stride,
// static_cast<half*>(matmul2_grads.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// beta,
// q_lin_grads_ptr,
// lead_dim,
// batch_stride,
// attn_batches);
// TODO: no matching function for call to "gemm_switch_fp32accum" (OK)
// Matmul1 Dgrad2
gemm_switch_fp32accum( state,
a_layout_n,
@@ -400,10 +612,32 @@ std::vector<torch::Tensor> bwd_cuda(
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_t,
// head_dim,
// k_seq_len,
// q_seq_len,
// scale,
// q_lin_results_ptr,
// lead_dim,
// batch_stride,
// static_cast<half*>(matmul2_grads.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// beta,
// k_lin_grads_ptr,
// lead_dim,
// batch_stride,
// attn_batches);
// TODO: cublasGemmEx --> rocblas_gemm_ex (ok)
// Input Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
@@ -411,43 +645,92 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r, // a_type
embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
//static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
rocblas_datatype_f16_r, // b_type
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r, // c_type
embed_dim,
CUDA_R_32F,
//CUBLAS_GEMM_ALGO10_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r, // d_type
embed_dim,
rocblas_datatype_f32_r, // compute_type
algo,
solution_index,
flags));
// TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_N,
// embed_dim,
// batches,
// output_lin_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(input_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(input_lin_output_grads.data_ptr()),
// //static_cast<const void*>(q_lin_grads_ptr),
// CUDA_R_16F,
// output_lin_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(input_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// //CUBLAS_GEMM_ALGO10_TENSOR_OP));
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// TODO: cublasGemmEx --> rocblas_gemm_ex (OK)
// Input Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_dim,
batches,
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r, // a_type
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F,
output_lin_dim,
rocblas_datatype_f16_r, // b_type
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
CUDA_R_16F,
rocblas_datatype_f16_r, // c_type
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r, // d_type
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
rocblas_datatype_f32_r, // compute_type
algo,
solution_index,
flags));
// TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_T,
// embed_dim,
// output_lin_dim,
// batches,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(inputs.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(q_lin_grads_ptr),
// CUDA_R_16F,
// output_lin_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(input_weight_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
// TODO: CUBLAS_DEFAULT_MATH (https://github.com/ROCmSoftwarePlatform/apex/commit/1fd257e2cd777f1ef7df37590f6dc6b2a73cc518) (ok)
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_grads,
@@ -458,6 +741,6 @@ std::vector<torch::Tensor> bwd_cuda(
};
}
} // end namespace cublas_gemmex
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn