Commit 8bdbb502 authored by hubertlu-tw

Hipify encdec_multihead_attn

parent ba0e5fa5
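The change follows the usual hipify pattern for these attention extensions: cublasGemmEx calls become rocblas_gemm_ex, the CUDA_R_16F/CUDA_R_32F datatype enums become rocblas_datatype_f16_r/rocblas_datatype_f32_r, the cublasSetMathMode calls (which have no rocBLAS equivalent) are commented out, and each original CUDA call is kept as a comment beside its replacement for reference. One structural difference worth noting: rocblas_gemm_ex takes a separate input matrix C and output matrix D plus algo/solution_index/flags arguments, so every call gains an extra pointer/datatype/leading-dimension triple that aliases the output. A minimal sketch of the mapping for one FP16 GEMM with FP32 accumulation (the wrapper name and the algo/solution_index/flags choices are assumptions for illustration, not part of this commit):

#include <rocblas.h>  // build with hipcc and link -lrocblas

// cublasGemmEx(handle, opA, opB, m, n, k, &alpha,
//              A, CUDA_R_16F, lda, B, CUDA_R_16F, ldb, &beta,
//              C, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)
// maps to:
rocblas_status gemm_fp16_fp32acc(rocblas_handle handle,
                                 rocblas_operation opA, rocblas_operation opB,
                                 rocblas_int m, rocblas_int n, rocblas_int k,
                                 const float* alpha,
                                 const void* A, rocblas_int lda,
                                 const void* B, rocblas_int ldb,
                                 const float* beta,
                                 void* C, rocblas_int ldc) {
  return rocblas_gemm_ex(handle, opA, opB, m, n, k,
                         alpha,
                         A, rocblas_datatype_f16_r, lda,
                         B, rocblas_datatype_f16_r, ldb,
                         beta,
                         C, rocblas_datatype_f16_r, ldc,  // C: input matrix
                         C, rocblas_datatype_f16_r, ldc,  // D: output, aliased to C
                         rocblas_datatype_f32_r,          // FP32 accumulation
                         rocblas_gemm_algo_standard,
                         0,    // solution_index: 0 lets rocBLAS choose
                         0);   // flags: none
}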
@@ -3,7 +3,7 @@
 namespace multihead_attn {
 namespace encdec {
-namespace cublas_gemmex {
+namespace rocblas_gemm_ex {
 std::vector<torch::Tensor> fwd_cuda(
 bool use_time_mask,
@@ -146,11 +146,11 @@ std::vector<torch::Tensor> bwd(
 );
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemm_ex
 } // end namespace encdec
 } // end namespace multihead_attn
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-m.def("forward", &multihead_attn::encdec::cublas_gemmex::fwd, "Encdec Multihead Attention Forward.");
-m.def("backward", &multihead_attn::encdec::cublas_gemmex::bwd, "Encdec Multihead Attention Backward.");
+m.def("forward", &multihead_attn::encdec::rocblas_gemm_ex::fwd, "Encdec Multihead Attention Forward.");
+m.def("backward", &multihead_attn::encdec::rocblas_gemm_ex::bwd, "Encdec Multihead Attention Backward.");
 }
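The remaining hunks are from the kernel implementation file (fwd_cuda/bwd_cuda). Besides the rocblas_gemm_ex conversions, the batched attention products go through the gemm_switch_fp32accum helper, whose hipified form likewise gains a duplicated output pointer/leading-dimension/stride triple, mirroring the separate C and D matrices of rocblas_gemm_strided_batched_ex. As a hedged sketch of what such a helper might ultimately dispatch to (the function name and dispatch details are assumptions; the helper's body is outside the visible hunks), a batched FP16 QK^T with FP32 accumulation could look like:

#include <rocblas.h>

// Hypothetical direct call corresponding to one gemm_switch_fp32accum use:
// strided-batched K^T * Q, output S aliased as both C (input) and D (output).
rocblas_status batched_qk(rocblas_handle handle,
                          rocblas_int m, rocblas_int n, rocblas_int k,
                          const float* scale, const float* beta,
                          const void* K, rocblas_int ldk, rocblas_stride strideK,
                          const void* Q, rocblas_int ldq, rocblas_stride strideQ,
                          void* S, rocblas_int lds, rocblas_stride strideS,
                          rocblas_int batch_count) {
  return rocblas_gemm_strided_batched_ex(handle,
      rocblas_operation_transpose, rocblas_operation_none,
      m, n, k,
      scale,
      K, rocblas_datatype_f16_r, ldk, strideK,
      Q, rocblas_datatype_f16_r, ldq, strideQ,
      beta,
      S, rocblas_datatype_f16_r, lds, strideS,   // C: input matrix
      S, rocblas_datatype_f16_r, lds, strideS,   // D: output, aliased to C
      batch_count,
      rocblas_datatype_f32_r,                    // FP32 accumulation
      rocblas_gemm_algo_standard, 0, 0);
}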
@@ -85,10 +85,12 @@ std::vector<torch::Tensor> fwd_cuda(
 char a_layout_t{'t'};
 char a_layout_n{'n'};
 char b_layout_n{'n'};
-THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+// TODO (OK)
+// THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
 // Input Linear Q Fwd
-THCublasCheck(cublasGemmEx(handle,
+// TODO (OK)
+THCublasCheck(rocblas_gemm_ex(handle,
 CUBLAS_OP_T,
 CUBLAS_OP_N,
 output_lin_q_dim,
@@ -96,20 +98,45 @@ std::vector<torch::Tensor> fwd_cuda(
 embed_dim,
 static_cast<const void*>(&alpha),
 static_cast<const void*>(input_weights_q.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 embed_dim,
 static_cast<const void*>(inputs_q.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 embed_dim,
 static_cast<const void*>(&beta),
 q_lin_results_ptr,
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 output_lin_q_dim,
-CUDA_R_32F,
-CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+q_lin_results_ptr,
+rocblas_datatype_f16_r,
+output_lin_q_dim,
+rocblas_datatype_f32_r,
+algo,
+solution_index,
+flags));
+// THCublasCheck(cublasGemmEx(handle,
+// CUBLAS_OP_T,
+// CUBLAS_OP_N,
+// output_lin_q_dim,
+// batches_q,
+// embed_dim,
+// static_cast<const void*>(&alpha),
+// static_cast<const void*>(input_weights_q.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// static_cast<const void*>(inputs_q.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// static_cast<const void*>(&beta),
+// q_lin_results_ptr,
+// CUDA_R_16F,
+// output_lin_q_dim,
+// CUDA_R_32F,
+// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 // Input Linear KV Fwd
-THCublasCheck(cublasGemmEx(handle,
+// TODO (OK)
+THCublasCheck(rocblas_gemm_ex(handle,
 CUBLAS_OP_T,
 CUBLAS_OP_N,
 output_lin_kv_dim,
@@ -117,19 +144,44 @@ std::vector<torch::Tensor> fwd_cuda(
 embed_dim,
 static_cast<const void*>(&alpha),
 static_cast<const void*>(input_weights_kv.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 embed_dim,
 static_cast<const void*>(inputs_kv.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 embed_dim,
 static_cast<const void*>(&beta),
 k_lin_results_ptr,
-CUDA_R_16F,
+rocblas_datatype_f16_r,
+output_lin_kv_dim,
+k_lin_results_ptr,
+rocblas_datatype_f16_r,
 output_lin_kv_dim,
-CUDA_R_32F,
-CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+rocblas_datatype_f32_r,
+algo,
+solution_index,
+flags));
+// THCublasCheck(cublasGemmEx(handle,
+// CUBLAS_OP_T,
+// CUBLAS_OP_N,
+// output_lin_kv_dim,
+// batches_kv,
+// embed_dim,
+// static_cast<const void*>(&alpha),
+// static_cast<const void*>(input_weights_kv.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// static_cast<const void*>(inputs_kv.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// static_cast<const void*>(&beta),
+// k_lin_results_ptr,
+// CUDA_R_16F,
+// output_lin_kv_dim,
+// CUDA_R_32F,
+// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
+// TODO (OK)
 gemm_switch_fp32accum( state,
 a_layout_t,
 b_layout_n,
@@ -146,8 +198,29 @@ std::vector<torch::Tensor> fwd_cuda(
 beta,
 static_cast<half*>(softmax_results_ptr),
 k_seq_len,
+k_seq_len*q_seq_len,
+static_cast<half*>(softmax_results_ptr),
+k_seq_len,
 k_seq_len*q_seq_len,
 attn_batches);
+// gemm_switch_fp32accum( state,
+// a_layout_t,
+// b_layout_n,
+// k_seq_len,
+// q_seq_len,
+// head_dim,
+// scale,
+// static_cast<const half*>(k_lin_results_ptr),
+// lead_dim_kv,
+// batch_stride_kv,
+// static_cast<const half*>(q_lin_results_ptr),
+// lead_dim_q,
+// batch_stride_q,
+// beta,
+// static_cast<half*>(softmax_results_ptr),
+// k_seq_len,
+// k_seq_len*q_seq_len,
+// attn_batches);
 // Padded Softmax
 bool softmax_success = false;
@@ -191,6 +264,7 @@ std::vector<torch::Tensor> fwd_cuda(
 }
 // Matmul2
+// TODO (OK)
 gemm_switch_fp32accum( state,
 a_layout_n,
 b_layout_n,
@@ -208,10 +282,32 @@ std::vector<torch::Tensor> fwd_cuda(
 static_cast<half*>(matmul2_results.data_ptr()),
 head_dim*attn_batches,
 head_dim,
+static_cast<half*>(matmul2_results.data_ptr()),
+head_dim*attn_batches,
+head_dim,
 attn_batches);
+// gemm_switch_fp32accum( state,
+// a_layout_n,
+// b_layout_n,
+// head_dim,
+// q_seq_len,
+// k_seq_len,
+// alpha,
+// static_cast<const half*>(v_lin_results_ptr),
+// lead_dim_kv,
+// batch_stride_kv,
+// (is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
+// k_seq_len,
+// k_seq_len*q_seq_len,
+// beta,
+// static_cast<half*>(matmul2_results.data_ptr()),
+// head_dim*attn_batches,
+// head_dim,
+// attn_batches);
 // Output Linear
-THCublasCheck(cublasGemmEx(handle,
+// TODO (OK)
+THCublasCheck(rocblas_gemm_ex(handle,
 CUBLAS_OP_T,
 CUBLAS_OP_N,
 embed_dim,
@@ -219,20 +315,45 @@ std::vector<torch::Tensor> fwd_cuda(
 embed_dim,
 static_cast<const void*>(&alpha),
 static_cast<const void*>(output_weights.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 embed_dim,
 static_cast<const void*>(matmul2_results.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 embed_dim,
 static_cast<const void*>(&beta),
 static_cast<void*>(outputs.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
+embed_dim,
+static_cast<void*>(outputs.data_ptr()),
+rocblas_datatype_f16_r,
 embed_dim,
-CUDA_R_32F,
-//CUBLAS_GEMM_ALGO1_TENSOR_OP));
-CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+rocblas_datatype_f32_r,
+algo,
+solution_index,
+flags));
+// THCublasCheck(cublasGemmEx(handle,
+// CUBLAS_OP_T,
+// CUBLAS_OP_N,
+// embed_dim,
+// batches_q,
+// embed_dim,
+// static_cast<const void*>(&alpha),
+// static_cast<const void*>(output_weights.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// static_cast<const void*>(matmul2_results.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// static_cast<const void*>(&beta),
+// static_cast<void*>(outputs.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// CUDA_R_32F,
+// //CUBLAS_GEMM_ALGO1_TENSOR_OP));
+// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+// TODO (OK)
+// THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
 return {
 input_lin_q_results,
@@ -311,11 +432,12 @@ std::vector<torch::Tensor> bwd_cuda(
 char a_layout_t{'t'};
 char b_layout_n{'n'};
 char b_layout_t{'t'};
-THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+// TODO (OK)
+// THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
 // Output Linear Dgrad
-THCublasCheck(cublasGemmEx(handle,
+// TODO (OK)
+THCublasCheck(rocblas_gemm_ex(handle,
 CUBLAS_OP_N,
 CUBLAS_OP_N,
 embed_dim,
@@ -323,20 +445,45 @@ std::vector<torch::Tensor> bwd_cuda(
 embed_dim,
 static_cast<const void*>(&alpha),
 static_cast<const void*>(output_weights.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 embed_dim,
 static_cast<const void*>(output_grads.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 embed_dim,
 static_cast<const void*>(&beta),
 static_cast<void*>(output_lin_grads.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 embed_dim,
-CUDA_R_32F,
-CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+static_cast<void*>(output_lin_grads.data_ptr()),
+rocblas_datatype_f16_r,
+embed_dim,
+rocblas_datatype_f32_r,
+algo,
+solution_index,
+flags));
+// THCublasCheck(cublasGemmEx(handle,
+// CUBLAS_OP_N,
+// CUBLAS_OP_N,
+// embed_dim,
+// batches_q,
+// embed_dim,
+// static_cast<const void*>(&alpha),
+// static_cast<const void*>(output_weights.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// static_cast<const void*>(output_grads.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// static_cast<const void*>(&beta),
+// static_cast<void*>(output_lin_grads.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// CUDA_R_32F,
+// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 // Output Linear Wgrad
-THCublasCheck(cublasGemmEx(handle,
+// TODO (OK)
+THCublasCheck(rocblas_gemm_ex(handle,
 CUBLAS_OP_N,
 CUBLAS_OP_T,
 embed_dim,
@@ -344,19 +491,44 @@ std::vector<torch::Tensor> bwd_cuda(
 batches_q,
 static_cast<const void*>(&alpha),
 static_cast<const void*>(matmul2_results.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 embed_dim,
 static_cast<const void*>(output_grads.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 embed_dim,
 static_cast<const void*>(&beta),
 static_cast<void*>(output_weight_grads.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
+embed_dim,
+static_cast<void*>(output_weight_grads.data_ptr()),
+rocblas_datatype_f16_r,
 embed_dim,
-CUDA_R_32F,
-CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+rocblas_datatype_f32_r,
+algo,
+solution_index,
+flags));
+// THCublasCheck(cublasGemmEx(handle,
+// CUBLAS_OP_N,
+// CUBLAS_OP_T,
+// embed_dim,
+// embed_dim,
+// batches_q,
+// static_cast<const void*>(&alpha),
+// static_cast<const void*>(matmul2_results.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// static_cast<const void*>(output_grads.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// static_cast<const void*>(&beta),
+// static_cast<void*>(output_weight_grads.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// CUDA_R_32F,
+// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 // MatMul2 Dgrad1
+// TODO (OK)
 gemm_switch_fp32accum( state,
 a_layout_t,
 b_layout_n,
@@ -374,9 +546,31 @@ std::vector<torch::Tensor> bwd_cuda(
 static_cast<half*>(matmul2_grads.data_ptr()),
 k_seq_len,
 k_seq_len*q_seq_len,
+static_cast<half*>(matmul2_grads.data_ptr()),
+k_seq_len,
+k_seq_len*q_seq_len,
 attn_batches);
+// gemm_switch_fp32accum( state,
+// a_layout_t,
+// b_layout_n,
+// k_seq_len,
+// q_seq_len,
+// head_dim,
+// alpha,
+// static_cast<const half*>(v_lin_results_ptr),
+// lead_dim_kv,
+// batch_stride_kv,
+// static_cast<const half*>(output_lin_grads.data_ptr()),
+// head_dim*attn_batches,
+// head_dim,
+// beta,
+// static_cast<half*>(matmul2_grads.data_ptr()),
+// k_seq_len,
+// k_seq_len*q_seq_len,
+// attn_batches);
 // Matmul2 Dgrad2
+// TODO (OK)
 gemm_switch_fp32accum( state,
 a_layout_n,
 b_layout_t,
@@ -394,7 +588,28 @@ std::vector<torch::Tensor> bwd_cuda(
 v_lin_grads_ptr,
 lead_dim_kv,
 batch_stride_kv,
+v_lin_grads_ptr,
+lead_dim_kv,
+batch_stride_kv,
 attn_batches);
+// gemm_switch_fp32accum( state,
+// a_layout_n,
+// b_layout_t,
+// head_dim,
+// k_seq_len,
+// q_seq_len,
+// alpha,
+// static_cast<const half*>(output_lin_grads.data_ptr()),
+// head_dim*attn_batches,
+// head_dim,
+// static_cast<const half*>(dropout_results.data_ptr()),
+// k_seq_len,
+// k_seq_len*q_seq_len,
+// beta,
+// v_lin_grads_ptr,
+// lead_dim_kv,
+// batch_stride_kv,
+// attn_batches);
 // Apply Dropout Mask and Scale by Dropout Probability
 apex_masked_scale_cuda<at::Half,float,uint32_t>(
@@ -416,6 +631,7 @@ std::vector<torch::Tensor> bwd_cuda(
 assert(softmax_success);
 // Matmul1 Dgrad1
+// TODO (OK)
 gemm_switch_fp32accum( state,
 a_layout_n,
 b_layout_n,
@@ -433,9 +649,31 @@ std::vector<torch::Tensor> bwd_cuda(
 q_lin_grads_ptr,
 lead_dim_q,
 batch_stride_q,
+q_lin_grads_ptr,
+lead_dim_q,
+batch_stride_q,
 attn_batches);
+// gemm_switch_fp32accum( state,
+// a_layout_n,
+// b_layout_n,
+// head_dim,
+// q_seq_len,
+// k_seq_len,
+// scale,
+// k_lin_results_ptr,
+// lead_dim_kv,
+// batch_stride_kv,
+// static_cast<half*>(matmul2_grads.data_ptr()),
+// k_seq_len,
+// k_seq_len*q_seq_len,
+// beta,
+// q_lin_grads_ptr,
+// lead_dim_q,
+// batch_stride_q,
+// attn_batches);
 // Matmul1 Dgrad2
+// TODO (OK)
 gemm_switch_fp32accum( state,
 a_layout_n,
 b_layout_t,
@@ -453,10 +691,32 @@ std::vector<torch::Tensor> bwd_cuda(
 k_lin_grads_ptr,
 lead_dim_kv,
 batch_stride_kv,
+k_lin_grads_ptr,
+lead_dim_kv,
+batch_stride_kv,
 attn_batches);
+// gemm_switch_fp32accum( state,
+// a_layout_n,
+// b_layout_t,
+// head_dim,
+// k_seq_len,
+// q_seq_len,
+// scale,
+// q_lin_results_ptr,
+// lead_dim_q,
+// batch_stride_q,
+// static_cast<half*>(matmul2_grads.data_ptr()),
+// k_seq_len,
+// k_seq_len*q_seq_len,
+// beta,
+// k_lin_grads_ptr,
+// lead_dim_kv,
+// batch_stride_kv,
+// attn_batches);
 // Input Linear Q Dgrad
-THCublasCheck(cublasGemmEx(handle,
+// TODO (OK)
+THCublasCheck(rocblas_gemm_ex(handle,
 CUBLAS_OP_N,
 CUBLAS_OP_N,
 embed_dim,
@@ -464,21 +724,46 @@ std::vector<torch::Tensor> bwd_cuda(
 output_lin_q_dim,
 static_cast<const void*>(&alpha),
 static_cast<const void*>(input_weights_q.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 embed_dim,
 static_cast<const void*>(q_lin_grads_ptr),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 output_lin_q_dim,
 static_cast<const void*>(&beta),
 static_cast<void*>(input_q_grads.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
+embed_dim,
+static_cast<void*>(input_q_grads.data_ptr()),
+rocblas_datatype_f16_r,
 embed_dim,
-CUDA_R_32F,
-//CUBLAS_GEMM_ALGO10_TENSOR_OP));
-CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+rocblas_datatype_f32_r,
+algo,
+solution_index,
+flags));
+// THCublasCheck(cublasGemmEx(handle,
+// CUBLAS_OP_N,
+// CUBLAS_OP_N,
+// embed_dim,
+// batches_q,
+// output_lin_q_dim,
+// static_cast<const void*>(&alpha),
+// static_cast<const void*>(input_weights_q.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// static_cast<const void*>(q_lin_grads_ptr),
+// CUDA_R_16F,
+// output_lin_q_dim,
+// static_cast<const void*>(&beta),
+// static_cast<void*>(input_q_grads.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// CUDA_R_32F,
+// //CUBLAS_GEMM_ALGO10_TENSOR_OP));
+// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 // Input Linear Q Wgrad
-THCublasCheck(cublasGemmEx(handle,
+// TODO (OK)
+THCublasCheck(rocblas_gemm_ex(handle,
 CUBLAS_OP_N,
 CUBLAS_OP_T,
 embed_dim,
@@ -486,20 +771,45 @@ std::vector<torch::Tensor> bwd_cuda(
 batches_q,
 static_cast<const void*>(&alpha),
 static_cast<const void*>(inputs_q.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 embed_dim,
 static_cast<const void*>(q_lin_grads_ptr),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 output_lin_q_dim,
 static_cast<const void*>(&beta),
 static_cast<void*>(input_weight_q_grads.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
+embed_dim,
+static_cast<void*>(input_weight_q_grads.data_ptr()),
+rocblas_datatype_f16_r,
 embed_dim,
-CUDA_R_32F,
-CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+rocblas_datatype_f32_r,
+algo,
+solution_index,
+flags));
+// THCublasCheck(cublasGemmEx(handle,
+// CUBLAS_OP_N,
+// CUBLAS_OP_T,
+// embed_dim,
+// output_lin_q_dim,
+// batches_q,
+// static_cast<const void*>(&alpha),
+// static_cast<const void*>(inputs_q.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// static_cast<const void*>(q_lin_grads_ptr),
+// CUDA_R_16F,
+// output_lin_q_dim,
+// static_cast<const void*>(&beta),
+// static_cast<void*>(input_weight_q_grads.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// CUDA_R_32F,
+// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 // Input Linear KV Dgrad
-THCublasCheck(cublasGemmEx(handle,
+// TODO (OK)
+THCublasCheck(rocblas_gemm_ex(handle,
 CUBLAS_OP_N,
 CUBLAS_OP_N,
 embed_dim,
@@ -507,21 +817,46 @@ std::vector<torch::Tensor> bwd_cuda(
 output_lin_kv_dim,
 static_cast<const void*>(&alpha),
 static_cast<const void*>(input_weights_kv.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 embed_dim,
 static_cast<const void*>(k_lin_grads_ptr),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 output_lin_kv_dim,
 static_cast<const void*>(&beta),
 static_cast<void*>(input_kv_grads.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 embed_dim,
-CUDA_R_32F,
-//CUBLAS_GEMM_ALGO10_TENSOR_OP));
-CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+static_cast<void*>(input_kv_grads.data_ptr()),
+rocblas_datatype_f16_r,
+embed_dim,
+rocblas_datatype_f32_r,
+algo,
+solution_index,
+flags));
+// THCublasCheck(cublasGemmEx(handle,
+// CUBLAS_OP_N,
+// CUBLAS_OP_N,
+// embed_dim,
+// batches_kv,
+// output_lin_kv_dim,
+// static_cast<const void*>(&alpha),
+// static_cast<const void*>(input_weights_kv.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// static_cast<const void*>(k_lin_grads_ptr),
+// CUDA_R_16F,
+// output_lin_kv_dim,
+// static_cast<const void*>(&beta),
+// static_cast<void*>(input_kv_grads.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// CUDA_R_32F,
+// //CUBLAS_GEMM_ALGO10_TENSOR_OP));
+// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 // Input Linear KV Wgrad
-THCublasCheck(cublasGemmEx(handle,
+// TODO (OK)
+THCublasCheck(rocblas_gemm_ex(handle,
 CUBLAS_OP_N,
 CUBLAS_OP_T,
 embed_dim,
@@ -529,18 +864,43 @@ std::vector<torch::Tensor> bwd_cuda(
 batches_kv,
 static_cast<const void*>(&alpha),
 static_cast<const void*>(inputs_kv.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 embed_dim,
 static_cast<const void*>(k_lin_grads_ptr),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
 output_lin_kv_dim,
 static_cast<const void*>(&beta),
 static_cast<void*>(input_weight_kv_grads.data_ptr()),
-CUDA_R_16F,
+rocblas_datatype_f16_r,
+embed_dim,
+static_cast<void*>(input_weight_kv_grads.data_ptr()),
+rocblas_datatype_f16_r,
 embed_dim,
-CUDA_R_32F,
-CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+rocblas_datatype_f32_r,
+algo,
+solution_index,
+flags));
+// THCublasCheck(cublasGemmEx(handle,
+// CUBLAS_OP_N,
+// CUBLAS_OP_T,
+// embed_dim,
+// output_lin_kv_dim,
+// batches_kv,
+// static_cast<const void*>(&alpha),
+// static_cast<const void*>(inputs_kv.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// static_cast<const void*>(k_lin_grads_ptr),
+// CUDA_R_16F,
+// output_lin_kv_dim,
+// static_cast<const void*>(&beta),
+// static_cast<void*>(input_weight_kv_grads.data_ptr()),
+// CUDA_R_16F,
+// embed_dim,
+// CUDA_R_32F,
+// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+// TODO
+// THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
 return {
 input_q_grads,
...
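One loose end for readers of this diff: the new rocblas_gemm_ex calls reference algo, solution_index, and flags, whose declarations sit in collapsed context outside the visible hunks. They are presumably defined once near the top of each function, along these lines (an assumption for illustration, not shown in this commit):

// Assumed declarations (not visible in the collapsed context above):
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
int32_t solution_index = 0;  // 0: let rocBLAS select a solution
uint32_t flags = 0;          // no special GEMM flags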