"vscode:/vscode.git/clone" did not exist on "ddf39d3fcee2ed42284bfb90c3ea37b302090dad"
Commit ba0e5fa5 authored by hubertlu-tw's avatar hubertlu-tw
Browse files

Hipify self_multihead_attn_bias_additive_mask.

parent c3ec9351
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
namespace multihead_attn { namespace multihead_attn {
namespace self_bias_additive_mask { namespace self_bias_additive_mask {
namespace cublas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -132,12 +132,12 @@ std::vector<torch::Tensor> bwd( ...@@ -132,12 +132,12 @@ std::vector<torch::Tensor> bwd(
); );
} }
} // end namespace cublas_gemmex } // end namespace rocblas_gemmex
} // end namespace self } // end namespace self
} // end namespace multihead_attn } // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward."); m.def("forward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward."); m.def("backward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
} }
#include <vector> #include <vector>
#include <math.h>
#include <iostream> #include <iostream>
#include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
//#include <cuda_profiler_api.h> //#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h" #include "strided_batched_gemm.h"
#include "softmax.h" #include "softmax.h"
...@@ -21,7 +21,7 @@ extern THCState *state; ...@@ -21,7 +21,7 @@ extern THCState *state;
namespace multihead_attn { namespace multihead_attn {
namespace self_bias_additive_mask { namespace self_bias_additive_mask {
namespace cublas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -81,11 +81,12 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -81,11 +81,12 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_t{'t'}; char a_layout_t{'t'};
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
// TODO: CUBLAS_TENSOR_OP_MATH (https://github.com/ROCmSoftwarePlatform/apex/commit/1fd257e2cd777f1ef7df37590f6dc6b2a73cc518) (ok)
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); // TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// TODO: cublasGemmEx --> rocblas_gemm_ex (OK)
// Input Linear Fwd // Input Linear Fwd
input_lin_results.copy_(input_biases); input_lin_results.copy_(input_biases);
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_dim, output_lin_dim,
...@@ -93,18 +94,42 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -93,18 +94,42 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r, // a_type
embed_dim, embed_dim,
static_cast<const void*>(inputs.data_ptr()), static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r, // b_type
embed_dim, embed_dim,
static_cast<const void*>(&beta_one), static_cast<const void*>(&beta_one),
q_lin_results_ptr, q_lin_results_ptr,
CUDA_R_16F, rocblas_datatype_f16_r, // c_type
output_lin_dim, output_lin_dim,
CUDA_R_32F, q_lin_results_ptr,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); rocblas_datatype_f16_r, // d_type
output_lin_dim,
rocblas_datatype_f32_r, // compute_type
algo,
solution_index,
flags));
// TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
// CUBLAS_OP_T,
// CUBLAS_OP_N,
// output_lin_dim,
// batches,
// embed_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(input_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(inputs.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta_one),
// q_lin_results_ptr,
// CUDA_R_16F,
// output_lin_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// TODO: no matching function for call to "gemm_switch_fp32accum" (OK)
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
a_layout_t, a_layout_t,
...@@ -123,7 +148,31 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -123,7 +148,31 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(bmm1_results_ptr), static_cast<half*>(bmm1_results_ptr),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches); attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_t,
// b_layout_n,
// k_seq_len,
// q_seq_len,
// head_dim,
// scale,
// static_cast<const half*>(k_lin_results_ptr),
// lead_dim,
// batch_stride,
// static_cast<const half*>(q_lin_results_ptr),
// lead_dim,
// batch_stride,
// beta_zero,
// static_cast<half*>(bmm1_results_ptr),
// k_seq_len,
// k_seq_len*q_seq_len,
// attn_batches);
// Padded Softmax // Padded Softmax
bool softmax_success = false; bool softmax_success = false;
if (is_training) { if (is_training) {
...@@ -150,6 +199,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -150,6 +199,7 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches*q_seq_len/sequences); attn_batches*q_seq_len/sequences);
} }
// TODO: no matching function for call to "gemm_switch_fp32accum" (OK)
// Matmul2 // Matmul2
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
a_layout_n, a_layout_n,
...@@ -168,12 +218,34 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -168,12 +218,34 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(matmul2_results.data_ptr()), static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches, head_dim*attn_batches,
head_dim, head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches); attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_n,
// head_dim,
// q_seq_len,
// k_seq_len,
// alpha,
// static_cast<const half*>(v_lin_results_ptr),
// lead_dim,
// batch_stride,
// static_cast<const half*>(dropout_results.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// beta_zero,
// static_cast<half*>(matmul2_results.data_ptr()),
// head_dim*attn_batches,
// head_dim,
// attn_batches);
outputs.copy_(output_biases); outputs.copy_(output_biases);
// TODO: cublasGemmEx --> rocblas_gemm_ex (OK)
// Output Linear // Output Linear
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -181,20 +253,44 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -181,20 +253,44 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r, // a_type
embed_dim, embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r, // b_type
embed_dim, embed_dim,
static_cast<const void*>(&beta_one), static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()), static_cast<void*>(outputs.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r, // c_type
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(outputs.data_ptr()),
//CUBLAS_GEMM_ALGO1_TENSOR_OP)); rocblas_datatype_f16_r, // d_type
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); embed_dim,
rocblas_datatype_f32_r, // compute_type
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); algo,
solution_index,
flags));
// TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
// CUBLAS_OP_T,
// CUBLAS_OP_N,
// embed_dim,
// batches,
// embed_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(output_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(matmul2_results.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta_one),
// static_cast<void*>(outputs.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// //CUBLAS_GEMM_ALGO1_TENSOR_OP));
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// TODO: CUBLAS_DEFAULT_MATH (https://github.com/ROCmSoftwarePlatform/apex/commit/1fd257e2cd777f1ef7df37590f6dc6b2a73cc518) (ok)
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_lin_results, input_lin_results,
...@@ -263,11 +359,12 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -263,11 +359,12 @@ std::vector<torch::Tensor> bwd_cuda(
char a_layout_t{'t'}; char a_layout_t{'t'};
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
// TODO: CUBLAS_TENSOR_OP_MATH (https://github.com/ROCmSoftwarePlatform/apex/commit/1fd257e2cd777f1ef7df37590f6dc6b2a73cc518) (ok)
// TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); // TODO: cublasGemmEx --> rocblas_gemm_ex (OK)
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -275,19 +372,44 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -275,19 +372,44 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r, // a_type
embed_dim, embed_dim,
static_cast<const void*>(output_grads.data_ptr()), static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r, // b_type
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()), static_cast<void*>(output_lin_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r, // c_type
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r, // d_type
embed_dim, embed_dim,
CUDA_R_32F, rocblas_datatype_f32_r, // compute_type
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_N,
// embed_dim,
// batches,
// embed_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(output_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(output_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(output_lin_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // TODO: CUBLAS_GEMM_DEFAULT_TENSOR_OP
// TODO: cublasGemmEx --> rocblas_gemm_ex (OK)
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -295,19 +417,44 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -295,19 +417,44 @@ std::vector<torch::Tensor> bwd_cuda(
batches, batches,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r, // a_type
embed_dim, embed_dim,
static_cast<const void*>(output_grads.data_ptr()), static_cast<const void*>(output_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r, // b_type
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()), static_cast<void*>(output_weight_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r, // c_type
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r, // d_type
embed_dim, embed_dim,
CUDA_R_32F, rocblas_datatype_f32_r, // compute_type
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_T,
// embed_dim,
// embed_dim,
// batches,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(matmul2_results.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(output_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(output_weight_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false); auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// TODO: no matching function for call to "gemm_switch_fp32accum" (OK)
// MatMul2 Dgrad1 // MatMul2 Dgrad1
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
a_layout_t, a_layout_t,
...@@ -326,8 +473,30 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -326,8 +473,30 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches); attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_t,
// b_layout_n,
// k_seq_len,
// q_seq_len,
// head_dim,
// alpha,
// static_cast<const half*>(v_lin_results_ptr),
// lead_dim,
// batch_stride,
// static_cast<const half*>(output_lin_grads.data_ptr()),
// head_dim*attn_batches,
// head_dim,
// beta,
// static_cast<half*>(matmul2_grads.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// attn_batches);
// TODO: no matching function for call to "gemm_switch_fp32accum" (OK)
// Matmul2 Dgrad2 // Matmul2 Dgrad2
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
a_layout_n, a_layout_n,
...@@ -346,7 +515,28 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -346,7 +515,28 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr, v_lin_grads_ptr,
lead_dim, lead_dim,
batch_stride, batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches); attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_t,
// head_dim,
// k_seq_len,
// q_seq_len,
// alpha,
// static_cast<const half*>(output_lin_grads.data_ptr()),
// head_dim*attn_batches,
// head_dim,
// static_cast<const half*>(dropout_results.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// beta,
// v_lin_grads_ptr,
// lead_dim,
// batch_stride,
// attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability // Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad // Softmax Grad
...@@ -362,7 +552,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -362,7 +552,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches*q_seq_len/sequences, attn_batches*q_seq_len/sequences,
attn_batches*q_seq_len, attn_batches*q_seq_len,
stream); stream);
// TODO: no matching function for call to "gemm_switch_fp32accum" (OK)
// Matmul1 Dgrad1 // Matmul1 Dgrad1
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
a_layout_n, a_layout_n,
...@@ -381,8 +571,30 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -381,8 +571,30 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr, q_lin_grads_ptr,
lead_dim, lead_dim,
batch_stride, batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches); attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_n,
// head_dim,
// q_seq_len,
// k_seq_len,
// scale,
// k_lin_results_ptr,
// lead_dim,
// batch_stride,
// static_cast<half*>(matmul2_grads.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// beta,
// q_lin_grads_ptr,
// lead_dim,
// batch_stride,
// attn_batches);
// TODO: no matching function for call to "gemm_switch_fp32accum" (OK)
// Matmul1 Dgrad2 // Matmul1 Dgrad2
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
a_layout_n, a_layout_n,
...@@ -401,9 +613,31 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -401,9 +613,31 @@ std::vector<torch::Tensor> bwd_cuda(
k_lin_grads_ptr, k_lin_grads_ptr,
lead_dim, lead_dim,
batch_stride, batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches); attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_t,
// head_dim,
// k_seq_len,
// q_seq_len,
// scale,
// q_lin_results_ptr,
// lead_dim,
// batch_stride,
// static_cast<half*>(matmul2_grads.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// beta,
// k_lin_grads_ptr,
// lead_dim,
// batch_stride,
// attn_batches);
// TODO: cublasGemmEx --> rocblas_gemm_ex (ok)
// Input Linear Dgrad // Input Linear Dgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -411,22 +645,46 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -411,22 +645,46 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim, output_lin_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r, // a_type
embed_dim, embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()), static_cast<const void*>(input_lin_output_grads.data_ptr()),
//static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r, // b_type
CUDA_R_16F,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()), static_cast<void*>(input_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r, // c_type
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(input_grads.data_ptr()),
//CUBLAS_GEMM_ALGO10_TENSOR_OP)); rocblas_datatype_f16_r, // d_type
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); embed_dim,
rocblas_datatype_f32_r, // compute_type
algo,
solution_index,
flags));
// TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_N,
// embed_dim,
// batches,
// output_lin_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(input_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(input_lin_output_grads.data_ptr()),
// //static_cast<const void*>(q_lin_grads_ptr),
// CUDA_R_16F,
// output_lin_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(input_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// //CUBLAS_GEMM_ALGO10_TENSOR_OP));
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// TODO: cublasGemmEx --> rocblas_gemm_ex (OK)
// Input Linear Wgrad // Input Linear Wgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -434,20 +692,45 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -434,20 +692,45 @@ std::vector<torch::Tensor> bwd_cuda(
batches, batches,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()), static_cast<const void*>(inputs.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r, // a_type
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F, rocblas_datatype_f16_r, // b_type
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()), static_cast<void*>(input_weight_grads.data_ptr()),
CUDA_R_16F, rocblas_datatype_f16_r, // c_type
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r, // d_type
embed_dim, embed_dim,
CUDA_R_32F, rocblas_datatype_f32_r, // compute_type
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_T,
// embed_dim,
// output_lin_dim,
// batches,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(inputs.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(q_lin_grads_ptr),
// CUDA_R_16F,
// output_lin_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(input_weight_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); // TODO: CUBLAS_DEFAULT_MATH (https://github.com/ROCmSoftwarePlatform/apex/commit/1fd257e2cd777f1ef7df37590f6dc6b2a73cc518) (ok)
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_grads, input_grads,
...@@ -458,6 +741,6 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -458,6 +741,6 @@ std::vector<torch::Tensor> bwd_cuda(
}; };
} }
} // end namespace cublas_gemmex } // end namespace rocblas_gemmex
} // end namespace self } // end namespace self
} // end namespace multihead_attn } // end namespace multihead_attn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment