Commit 9319318d authored by hubertlu-tw's avatar hubertlu-tw
Browse files

Fix namespace for pybind11

Fix rocblas_gemmex namespace

Fix namespace

Clean up comments
parent 83181423
......@@ -3,7 +3,7 @@
namespace multihead_attn {
namespace encdec {
namespace rocblas_gemm_ex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
......@@ -151,6 +151,6 @@ std::vector<torch::Tensor> bwd(
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::encdec::rocblas_gemm_ex::fwd, "Encdec Multihead Attention Forward.");
m.def("backward", &multihead_attn::encdec::rocblas_gemm_ex::bwd, "Encdec Multihead Attention Backward.");
m.def("forward", &multihead_attn::encdec::rocblas_gemmex::fwd, "Encdec Multihead Attention Forward.");
m.def("backward", &multihead_attn::encdec::rocblas_gemmex::bwd, "Encdec Multihead Attention Backward.");
}
......@@ -692,6 +692,6 @@ std::vector<torch::Tensor> bwd_cuda(
};
}
} // end namespace cublas_gemmex
} // end namespace rocblas_gemmex
} // end namespace encdec_norm_add
} // end namespace multihead_attn
......@@ -3,7 +3,7 @@
namespace multihead_attn {
namespace self_bias {
namespace rocblas_gemm_ex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
......@@ -128,12 +128,12 @@ std::vector<torch::Tensor> bwd(
);
}
} // end namespace rocblas_gemm_ex
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias::rocblas_gemm_ex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias::rocblas_gemm_ex::bwd, "Self Multihead Attention with Bias -- Backward.");
m.def("forward", &multihead_attn::self_bias::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
......@@ -83,11 +83,9 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_t{'t'};
char a_layout_n{'n'};
char b_layout_n{'n'};
// TODO (OK)
// THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
input_lin_results.copy_(input_biases);
// TODO (OK)
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
......@@ -105,36 +103,15 @@ std::vector<torch::Tensor> fwd_cuda(
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
q_lin_results_ptr, //
rocblas_datatype_f16_r, //
output_lin_dim, //
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
algo,
solution_index,
flags));
// THCublasCheck(cublasGemmEx(handle,
// CUBLAS_OP_T,
// CUBLAS_OP_N,
// output_lin_dim,
// batches,
// embed_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(input_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(inputs.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta_one),
// q_lin_results_ptr,
// CUDA_R_16F,
// output_lin_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
// TODO (OK)
gemm_switch_fp32accum( state,
a_layout_t,
b_layout_n,
......@@ -156,24 +133,7 @@ std::vector<torch::Tensor> fwd_cuda(
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_t,
// b_layout_n,
// k_seq_len,
// q_seq_len,
// head_dim,
// scale,
// static_cast<const half*>(k_lin_results_ptr),
// lead_dim,
// batch_stride,
// static_cast<const half*>(q_lin_results_ptr),
// lead_dim,
// batch_stride,
// beta_zero,
// static_cast<half*>(softmax_results_ptr),
// k_seq_len,
// k_seq_len*q_seq_len,
// attn_batches);
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
......@@ -214,7 +174,6 @@ std::vector<torch::Tensor> fwd_cuda(
}
// Matmul2
// TODO (OK)
gemm_switch_fp32accum( state,
a_layout_n,
b_layout_n,
......@@ -236,29 +195,10 @@ std::vector<torch::Tensor> fwd_cuda(
head_dim*attn_batches,
head_dim,
attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_n,
// head_dim,
// q_seq_len,
// k_seq_len,
// alpha,
// static_cast<const half*>(v_lin_results_ptr),
// lead_dim,
// batch_stride,
// (is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
// k_seq_len,
// k_seq_len*q_seq_len,
// beta_zero,
// static_cast<half*>(matmul2_results.data_ptr()),
// head_dim*attn_batches,
// head_dim,
// attn_batches);
outputs.copy_(output_biases);
// Output Linear
// TODO (OK)
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
......@@ -283,28 +223,6 @@ std::vector<torch::Tensor> fwd_cuda(
algo,
solution_index,
flags));
// THCublasCheck(cublasGemmEx(handle,
// CUBLAS_OP_T,
// CUBLAS_OP_N,
// embed_dim,
// batches,
// embed_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(output_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(matmul2_results.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta_one),
// static_cast<void*>(outputs.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// //CUBLAS_GEMM_ALGO1_TENSOR_OP));
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// TODO (OK)
// THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_lin_results,
......@@ -372,11 +290,8 @@ std::vector<torch::Tensor> bwd_cuda(
char a_layout_t{'t'};
char b_layout_n{'n'};
char b_layout_t{'t'};
// TODO (OK)
// THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad
// TODO (OK)
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
......@@ -401,28 +316,8 @@ std::vector<torch::Tensor> bwd_cuda(
algo,
solution_index,
flags));
// THCublasCheck(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_N,
// embed_dim,
// batches,
// embed_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(output_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(output_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(output_lin_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Output Linear Wgrad
// TODO (OK)
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
......@@ -447,29 +342,9 @@ std::vector<torch::Tensor> bwd_cuda(
algo,
solution_index,
flags));
// THCublasCheck(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_T,
// embed_dim,
// embed_dim,
// batches,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(matmul2_results.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(output_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(output_weight_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1
// TODO (OK)
gemm_switch_fp32accum( state,
a_layout_t,
b_layout_n,
......@@ -491,27 +366,8 @@ std::vector<torch::Tensor> bwd_cuda(
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_t,
// b_layout_n,
// k_seq_len,
// q_seq_len,
// head_dim,
// alpha,
// static_cast<const half*>(v_lin_results_ptr),
// lead_dim,
// batch_stride,
// static_cast<const half*>(output_lin_grads.data_ptr()),
// head_dim*attn_batches,
// head_dim,
// beta,
// static_cast<half*>(matmul2_grads.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// attn_batches);
// Matmul2 Dgrad2
// TODO (OK)
gemm_switch_fp32accum( state,
a_layout_n,
b_layout_t,
......@@ -533,24 +389,6 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim,
batch_stride,
attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_t,
// head_dim,
// k_seq_len,
// q_seq_len,
// alpha,
// static_cast<const half*>(output_lin_grads.data_ptr()),
// head_dim*attn_batches,
// head_dim,
// static_cast<const half*>(dropout_results.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// beta,
// v_lin_grads_ptr,
// lead_dim,
// batch_stride,
// attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
......@@ -565,7 +403,6 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches*q_seq_len, stream);
// Matmul1 Dgrad1
// TODO (OK)
gemm_switch_fp32accum( state,
a_layout_n,
b_layout_n,
......@@ -587,27 +424,8 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim,
batch_stride,
attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_n,
// head_dim,
// q_seq_len,
// k_seq_len,
// scale,
// k_lin_results_ptr,
// lead_dim,
// batch_stride,
// static_cast<half*>(matmul2_grads.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// beta,
// q_lin_grads_ptr,
// lead_dim,
// batch_stride,
// attn_batches);
// Matmul1 Dgrad2
// TODO (OK)
gemm_switch_fp32accum( state,
a_layout_n,
b_layout_t,
......@@ -629,26 +447,7 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim,
batch_stride,
attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_t,
// head_dim,
// k_seq_len,
// q_seq_len,
// scale,
// q_lin_results_ptr,
// lead_dim,
// batch_stride,
// static_cast<half*>(matmul2_grads.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// beta,
// k_lin_grads_ptr,
// lead_dim,
// batch_stride,
// attn_batches);
// Input Linear Dgrad
// TODO (OK)
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
......@@ -673,30 +472,8 @@ std::vector<torch::Tensor> bwd_cuda(
algo,
solution_index,
flags));
// THCublasCheck(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_N,
// embed_dim,
// batches,
// output_lin_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(input_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(input_lin_output_grads.data_ptr()),
// //static_cast<const void*>(q_lin_grads_ptr),
// CUDA_R_16F,
// output_lin_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(input_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// //CUBLAS_GEMM_ALGO10_TENSOR_OP));
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear Wgrad
// TODO (OK)
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
......@@ -721,29 +498,8 @@ std::vector<torch::Tensor> bwd_cuda(
algo,
solution_index,
flags));
// THCublasCheck(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_T,
// embed_dim,
// output_lin_dim,
// batches,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(inputs.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(q_lin_grads_ptr),
// CUDA_R_16F,
// output_lin_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(input_weight_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
// TODO (OK)
// THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_grads,
......@@ -754,6 +510,6 @@ std::vector<torch::Tensor> bwd_cuda(
};
}
} // end namespace cublas_gemmex
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
......@@ -3,7 +3,7 @@
namespace multihead_attn {
namespace self {
namespace rocblas_gemm_ex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
......@@ -126,7 +126,7 @@ std::vector<torch::Tensor> bwd(
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self::rocblas_gemm_ex::fwd, "Self Multihead Attention Forward.");
m.def("backward", &multihead_attn::self::rocblas_gemm_ex::bwd, "Self Multihead Attention Backward.");
m.def("forward", &multihead_attn::self::rocblas_gemmex::fwd, "Self Multihead Attention Forward.");
m.def("backward", &multihead_attn::self::rocblas_gemmex::bwd, "Self Multihead Attention Backward.");
}
......@@ -24,7 +24,7 @@ extern THCState *state;
namespace multihead_attn {
namespace self {
namespace cublas_gemmex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
......@@ -80,10 +80,8 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_t{'t'};
char a_layout_n{'n'};
char b_layout_n{'n'};
// TODO (OK)
// THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
// TODO (OK)
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
......@@ -108,28 +106,8 @@ std::vector<torch::Tensor> fwd_cuda(
algo,
solution_index,
flags));
// THCublasCheck(cublasGemmEx(handle,
// CUBLAS_OP_T,
// CUBLAS_OP_N,
// output_lin_dim,
// batches,
// embed_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(input_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(inputs.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta),
// q_lin_results_ptr,
// CUDA_R_16F,
// output_lin_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
// TODO (OK)
gemm_switch_fp32accum( state,
a_layout_t,
b_layout_n,
......@@ -151,24 +129,6 @@ std::vector<torch::Tensor> fwd_cuda(
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_t,
// b_layout_n,
// k_seq_len,
// q_seq_len,
// head_dim,
// scale,
// static_cast<const half*>(k_lin_results_ptr),
// lead_dim,
// batch_stride,
// static_cast<const half*>(q_lin_results_ptr),
// lead_dim,
// batch_stride,
// beta,
// static_cast<half*>(softmax_results_ptr),
// k_seq_len,
// k_seq_len*q_seq_len,
// attn_batches);
// Padded Softmax
bool softmax_success = false;
......@@ -212,7 +172,6 @@ std::vector<torch::Tensor> fwd_cuda(
}
// Matmul2
// TODO (OK)
gemm_switch_fp32accum( state,
a_layout_n,
b_layout_n,
......@@ -234,27 +193,8 @@ std::vector<torch::Tensor> fwd_cuda(
head_dim*attn_batches,
head_dim,
attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_n,
// head_dim,
// q_seq_len,
// k_seq_len,
// alpha,
// static_cast<const half*>(v_lin_results_ptr),
// lead_dim,
// batch_stride,
// (is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
// k_seq_len,
// k_seq_len*q_seq_len,
// beta,
// static_cast<half*>(matmul2_results.data_ptr()),
// head_dim*attn_batches,
// head_dim,
// attn_batches);
// Output Linear
// TODO
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
......@@ -279,27 +219,6 @@ std::vector<torch::Tensor> fwd_cuda(
algo,
solution_index,
flags));
// THCublasCheck(cublasGemmEx(handle,
// CUBLAS_OP_T,
// CUBLAS_OP_N,
// embed_dim,
// batches,
// embed_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(output_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(matmul2_results.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(outputs.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// TODO (OK)
// THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_lin_results,
......@@ -367,11 +286,8 @@ std::vector<torch::Tensor> bwd_cuda(
char a_layout_t{'t'};
char b_layout_n{'n'};
char b_layout_t{'t'};
// TODO (OK)
// THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad
// TODO (OK)
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
......@@ -396,28 +312,8 @@ std::vector<torch::Tensor> bwd_cuda(
algo,
solution_index,
flags));
// THCublasCheck(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_N,
// embed_dim,
// batches,
// embed_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(output_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(output_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(output_lin_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Output Linear Wgrad
// TODO (OOK)
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
......@@ -442,28 +338,8 @@ std::vector<torch::Tensor> bwd_cuda(
algo,
solution_index,
flags));
// THCublasCheck(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_T,
// embed_dim,
// embed_dim,
// batches,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(matmul2_results.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(output_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(output_weight_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// MatMul2 Dgrad1
// TODO (OK)
gemm_switch_fp32accum( state,
a_layout_t,
b_layout_n,
......@@ -485,27 +361,8 @@ std::vector<torch::Tensor> bwd_cuda(
k_seq_len,
k_seq_len*q_seq_len,
attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_t,
// b_layout_n,
// k_seq_len,
// q_seq_len,
// head_dim,
// alpha,
// static_cast<const half*>(v_lin_results_ptr),
// lead_dim,
// batch_stride,
// static_cast<const half*>(output_lin_grads.data_ptr()),
// head_dim*attn_batches,
// head_dim,
// beta,
// static_cast<half*>(matmul2_grads.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// attn_batches);
// Matmul2 Dgrad2
// TODO (OK)
gemm_switch_fp32accum( state,
a_layout_n,
b_layout_t,
......@@ -527,24 +384,6 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim,
batch_stride,
attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_t,
// head_dim,
// k_seq_len,
// q_seq_len,
// alpha,
// static_cast<const half*>(output_lin_grads.data_ptr()),
// head_dim*attn_batches,
// head_dim,
// static_cast<const half*>(dropout_results.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// beta,
// v_lin_grads_ptr,
// lead_dim,
// batch_stride,
// attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>(
......@@ -566,7 +405,6 @@ std::vector<torch::Tensor> bwd_cuda(
assert(softmax_success);
// Matmul1 Dgrad1
// TODO (OK)
gemm_switch_fp32accum( state,
a_layout_n,
b_layout_n,
......@@ -588,27 +426,8 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim,
batch_stride,
attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_n,
// head_dim,
// q_seq_len,
// k_seq_len,
// scale,
// k_lin_results_ptr,
// lead_dim,
// batch_stride,
// static_cast<half*>(matmul2_grads.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// beta,
// q_lin_grads_ptr,
// lead_dim,
// batch_stride,
// attn_batches);
// Matmul1 Dgrad2
// TODO (OK)
gemm_switch_fp32accum( state,
a_layout_n,
b_layout_t,
......@@ -630,27 +449,8 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim,
batch_stride,
attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_t,
// head_dim,
// k_seq_len,
// q_seq_len,
// scale,
// q_lin_results_ptr,
// lead_dim,
// batch_stride,
// static_cast<half*>(matmul2_grads.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// beta,
// k_lin_grads_ptr,
// lead_dim,
// batch_stride,
// attn_batches);
// Input Linear Dgrad
// TODO (OK)
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
......@@ -675,28 +475,8 @@ std::vector<torch::Tensor> bwd_cuda(
algo,
solution_index,
flags));
// THCublasCheck(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_N,
// embed_dim,
// batches,
// output_lin_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(input_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(q_lin_grads_ptr),
// CUDA_R_16F,
// output_lin_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(input_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear Wgrad
// TODO (OK)
THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
......@@ -721,27 +501,6 @@ std::vector<torch::Tensor> bwd_cuda(
algo,
solution_index,
flags));
// THCublasCheck(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_T,
// embed_dim,
// output_lin_dim,
// batches,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(inputs.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(q_lin_grads_ptr),
// CUDA_R_16F,
// output_lin_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(input_weight_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// TODO (OK)
// THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_grads,
......@@ -750,6 +509,6 @@ std::vector<torch::Tensor> bwd_cuda(
};
}
} // end namespace cublas_gemmex
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment