Commit 61416180 authored by hubertlu-tw

Hipify self_multihead_attn_bias

Fix some spacing
parent 8bdbb502
@@ -3,7 +3,7 @@
 namespace multihead_attn {
 namespace self_bias {
-namespace cublas_gemmex {
+namespace rocblas_gemm_ex {
 std::vector<torch::Tensor> fwd_cuda(
 bool use_time_mask,
@@ -128,12 +128,12 @@ std::vector<torch::Tensor> bwd(
 );
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemm_ex
 } // end namespace self
 } // end namespace multihead_attn
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-m.def("forward", &multihead_attn::self_bias::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
-m.def("backward", &multihead_attn::self_bias::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
+m.def("forward", &multihead_attn::self_bias::rocblas_gemm_ex::fwd, "Self Multihead Attention with Bias -- Forward.");
+m.def("backward", &multihead_attn::self_bias::rocblas_gemm_ex::bwd, "Self Multihead Attention with Bias -- Backward.");
 }
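For readers following the hunks below: cublasGemmEx writes its result into the C matrix, while rocblas_gemm_ex takes a separate output matrix D (pointer, datatype, leading dimension) plus solution_index and flags arguments. Passing the C buffer again as D, with the same type and leading dimension, reproduces the cuBLAS in-place behaviour, and that is the pattern every GEMM call in this file follows. The sketch below is illustrative only; the wrapper name gemm_ex_cublas_style and the choice of solution_index = 0 are not part of this commit.

#include <rocblas/rocblas.h>   // exposed as <rocblas.h> on older ROCm installs
#include <cstdint>

// Illustrative wrapper with a cuBLAS-GemmEx-shaped argument list that forwards to
// rocblas_gemm_ex. The C buffer is passed a second time as D so the result is
// written in place, mirroring what the hipified calls in this file do.
inline rocblas_status gemm_ex_cublas_style(rocblas_handle handle,
                                           rocblas_operation trans_a,
                                           rocblas_operation trans_b,
                                           rocblas_int m, rocblas_int n, rocblas_int k,
                                           const void* alpha,
                                           const void* A, rocblas_datatype a_type, rocblas_int lda,
                                           const void* B, rocblas_datatype b_type, rocblas_int ldb,
                                           const void* beta,
                                           void* C, rocblas_datatype c_type, rocblas_int ldc,
                                           rocblas_datatype compute_type,
                                           rocblas_gemm_algo algo) {
  const int32_t solution_index = 0;                 // 0 lets rocBLAS pick a solution
  const uint32_t flags = rocblas_gemm_flags_none;   // no special behaviour requested
  return rocblas_gemm_ex(handle, trans_a, trans_b, m, n, k,
                         alpha,
                         A, a_type, lda,
                         B, b_type, ldb,
                         beta,
                         C, c_type, ldc,   // C: beta-scaled input of the accumulation
                         C, c_type, ldc,   // D: output, aliased to C for an in-place update
                         compute_type, algo, solution_index, flags);
}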
@@ -124,13 +124,13 @@ std::vector<torch::Tensor> fwd_cuda(
 q_lin_results_ptr,
 c_type,
 output_lin_dim,
-q_lin_results_ptr,
-d_type,
-output_lin_dim,
+q_lin_results_ptr,
+d_type,
+output_lin_dim,
 compute_type,
 algo,
-solution_index,
-flags));
+solution_index,
+flags));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
 gemm_switch_fp32accum( state,
@@ -150,9 +150,9 @@ std::vector<torch::Tensor> fwd_cuda(
 static_cast<half*>(softmax_results_ptr),
 k_seq_len,
 k_seq_len*q_seq_len,
-static_cast<half*>(softmax_results_ptr),
-k_seq_len,
-k_seq_len*q_seq_len,
+static_cast<half*>(softmax_results_ptr),
+k_seq_len,
+k_seq_len*q_seq_len,
 attn_batches);
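The gemm_switch_fp32accum hunks (this one and the MatMul2 one below) show the same pattern for the batched attention GEMMs: the half-precision result pointer, its leading dimension, and its per-batch stride are now passed twice, once for C and once for D. A plausible ROCm backend for such a wrapper is rocblas_gemm_strided_batched_ex with FP16 storage and FP32 accumulation; the sketch below assumes that mapping (the wrapper's real dispatch logic sits outside this diff), and the name batched_gemm_fp32accum is illustrative.

#include <rocblas/rocblas.h>   // <rocblas.h> on older ROCm

// Illustrative: one strided-batched GEMM with FP16 storage and FP32 accumulation.
// c/ldc/stride_c and d/ldd/stride_d mirror the duplicated pointer/lead-dim/stride
// triple in the hunk above; d aliases c for an in-place result.
inline rocblas_status batched_gemm_fp32accum(rocblas_handle handle,
                                             rocblas_operation trans_a, rocblas_operation trans_b,
                                             rocblas_int m, rocblas_int n, rocblas_int k,
                                             const float* alpha, const float* beta,
                                             const void* a, rocblas_int lda, rocblas_stride stride_a,
                                             const void* b, rocblas_int ldb, rocblas_stride stride_b,
                                             void* c, rocblas_int ldc, rocblas_stride stride_c,
                                             rocblas_int batch_count) {
  return rocblas_gemm_strided_batched_ex(handle, trans_a, trans_b, m, n, k,
                                         alpha,
                                         a, rocblas_datatype_f16_r, lda, stride_a,
                                         b, rocblas_datatype_f16_r, ldb, stride_b,
                                         beta,
                                         c, rocblas_datatype_f16_r, ldc, stride_c,   // C
                                         c, rocblas_datatype_f16_r, ldc, stride_c,   // D aliases C
                                         batch_count,
                                         rocblas_datatype_f32_r,                     // FP32 accumulation
                                         rocblas_gemm_algo_standard,
                                         0 /*solution_index*/, rocblas_gemm_flags_none);
}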
 // Padded Softmax
@@ -215,9 +215,9 @@ std::vector<torch::Tensor> fwd_cuda(
 static_cast<half*>(matmul2_results.data_ptr()),
 head_dim*attn_batches,
 head_dim,
-static_cast<half*>(matmul2_results.data_ptr()),
-head_dim*attn_batches,
-head_dim,
+static_cast<half*>(matmul2_results.data_ptr()),
+head_dim*attn_batches,
+head_dim,
 attn_batches);
 // Output Linear
@@ -238,13 +238,13 @@ std::vector<torch::Tensor> fwd_cuda(
 static_cast<void*>(output_lin_results.data_ptr()),
 c_type,
 embed_dim,
-static_cast<void*>(output_lin_results.data_ptr()),
-d_type,
-embed_dim,
+static_cast<void*>(output_lin_results.data_ptr()),
+d_type,
+embed_dim,
 compute_type,
 algo,
-solution_index,
-flags));
+solution_index,
+flags));
 // End-of-block Dropout-Add
@@ -372,13 +372,13 @@ std::vector<torch::Tensor> bwd_cuda(
 static_cast<void*>(output_lin_grads.data_ptr()),
 c_type,
 embed_dim,
-static_cast<void*>(output_lin_grads.data_ptr()),
-d_type,
-embed_dim,
+static_cast<void*>(output_lin_grads.data_ptr()),
+d_type,
+embed_dim,
 compute_type,
 algo,
-solution_index,
-flags));
+solution_index,
+flags));
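The backward pass needs two GEMMs per linear layer, which is why the hunks come in Dgrad/Wgrad pairs here and again for the input linear further down: assuming the usual column-major convention Y = W^T * X for the projections, the data gradient is dX = W * dY and the weight gradient is dW = X * dY^T. The sketch below spells that pattern out with rocblas_gemm_ex; the exact transpose flags and dimensions used by apex live in the collapsed lines of these hunks, so treat the operand order here as a generic illustration rather than a transcription.

#include <rocblas/rocblas.h>   // <rocblas.h> on older ROCm

// Generic illustration of the Dgrad/Wgrad pair for a linear layer Y = W^T * X
// (column-major; W is in_dim x out_dim, X is in_dim x batch, Y/dY are out_dim x batch).
// Buffers are FP16 with FP32 accumulation, matching the types used in the hunks.
// Status checks are omitted for brevity; the real file wraps each call in THCublasCheck.
inline void linear_backward_sketch(rocblas_handle handle,
                                   rocblas_int in_dim, rocblas_int out_dim, rocblas_int batch,
                                   const float* one, const float* zero,   // alpha / beta
                                   const void* W, const void* X, const void* dY,
                                   void* dX, void* dW) {
  const rocblas_datatype f16 = rocblas_datatype_f16_r;
  const rocblas_datatype f32 = rocblas_datatype_f32_r;

  // Dgrad: dX = W * dY   (in_dim x batch)
  rocblas_gemm_ex(handle, rocblas_operation_none, rocblas_operation_none,
                  in_dim, batch, out_dim,
                  one,
                  W,  f16, in_dim,
                  dY, f16, out_dim,
                  zero,
                  dX, f16, in_dim,
                  dX, f16, in_dim,      // D aliases C
                  f32, rocblas_gemm_algo_standard, 0, rocblas_gemm_flags_none);

  // Wgrad: dW = X * dY^T   (in_dim x out_dim)
  rocblas_gemm_ex(handle, rocblas_operation_none, rocblas_operation_transpose,
                  in_dim, out_dim, batch,
                  one,
                  X,  f16, in_dim,
                  dY, f16, out_dim,
                  zero,
                  dW, f16, in_dim,
                  dW, f16, in_dim,      // D aliases C
                  f32, rocblas_gemm_algo_standard, 0, rocblas_gemm_flags_none);
}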
 // Output Linear Wgrad
 THCublasCheck(rocblas_gemm_ex(handle,
@@ -398,13 +398,13 @@ std::vector<torch::Tensor> bwd_cuda(
 static_cast<void*>(output_weight_grads.data_ptr()),
 c_type,
 embed_dim,
-static_cast<void*>(output_weight_grads.data_ptr()),
-d_type,
-embed_dim,
+static_cast<void*>(output_weight_grads.data_ptr()),
+d_type,
+embed_dim,
 compute_type,
 algo,
-solution_index,
-flags));
+solution_index,
+flags));
 // MatMul2 Dgrad1
 gemm_switch_fp32accum( state,
@@ -424,9 +424,9 @@ std::vector<torch::Tensor> bwd_cuda(
 static_cast<half*>(matmul2_grads.data_ptr()),
 k_seq_len,
 k_seq_len*q_seq_len,
-static_cast<half*>(matmul2_grads.data_ptr()),
-k_seq_len,
-k_seq_len*q_seq_len,
+static_cast<half*>(matmul2_grads.data_ptr()),
+k_seq_len,
+k_seq_len*q_seq_len,
 attn_batches);
 // Matmul2 Dgrad2
@@ -447,9 +447,9 @@ std::vector<torch::Tensor> bwd_cuda(
 v_lin_grads_ptr,
 lead_dim,
 batch_stride,
-v_lin_grads_ptr,
-lead_dim,
-batch_stride,
+v_lin_grads_ptr,
+lead_dim,
+batch_stride,
 attn_batches);
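The batched attention matmuls follow the same two-GEMM pattern in backward, which is why these hunks pair up as Dgrad1/Dgrad2 here and for Matmul1 below: for a batched product C = A * B, the gradients are dA = dC * B^T and dB = A^T * dC. The sketch uses rocblas_sgemm_strided_batched in plain FP32 to keep the argument list short; the file itself goes through gemm_switch_fp32accum over half-precision data with FP32 accumulation, and the shapes here (per batch A is m x k, B is k x n, column-major) are generic, not a transcription of the collapsed lines.

#include <rocblas/rocblas.h>   // <rocblas.h> on older ROCm

// Generic illustration: gradients of a strided-batched product C = A * B.
inline void batched_matmul_backward_sketch(rocblas_handle handle,
                                           rocblas_int m, rocblas_int n, rocblas_int k,
                                           rocblas_int batch_count,
                                           const float* A,  rocblas_stride stride_a,
                                           const float* B,  rocblas_stride stride_b,
                                           const float* dC, rocblas_stride stride_c,
                                           float* dA, float* dB) {
  const float one = 1.f, zero = 0.f;

  // Dgrad1: dA = dC * B^T   (m x k per batch)
  rocblas_sgemm_strided_batched(handle, rocblas_operation_none, rocblas_operation_transpose,
                                m, k, n,
                                &one,
                                dC, m, stride_c,
                                B,  k, stride_b,
                                &zero,
                                dA, m, stride_a,
                                batch_count);

  // Dgrad2: dB = A^T * dC   (k x n per batch)
  rocblas_sgemm_strided_batched(handle, rocblas_operation_transpose, rocblas_operation_none,
                                k, n, m,
                                &one,
                                A,  m, stride_a,
                                dC, m, stride_c,
                                &zero,
                                dB, k, stride_b,
                                batch_count);
}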
 // Apply Dropout Mask and Scale by Dropout Probability
@@ -489,9 +489,9 @@ std::vector<torch::Tensor> bwd_cuda(
 q_lin_grads_ptr,
 lead_dim,
 batch_stride,
-q_lin_grads_ptr,
-lead_dim,
-batch_stride,
+q_lin_grads_ptr,
+lead_dim,
+batch_stride,
 attn_batches);
 // Matmul1 Dgrad2
@@ -512,7 +512,7 @@ std::vector<torch::Tensor> bwd_cuda(
 k_lin_grads_ptr,
 lead_dim,
 batch_stride,
-k_lin_grads_ptr,
+k_lin_grads_ptr,
 lead_dim,
 batch_stride,
 attn_batches);
@@ -536,13 +536,13 @@ std::vector<torch::Tensor> bwd_cuda(
 static_cast<void*>(input_lin_grads.data_ptr()),
 c_type,
 embed_dim,
-static_cast<void*>(input_lin_grads.data_ptr()),
-d_type,
-embed_dim,
+static_cast<void*>(input_lin_grads.data_ptr()),
+d_type,
+embed_dim,
 compute_type,
 algo,
-solution_index,
-flags));
+solution_index,
+flags));
 // Input Linear Wgrad
 THCublasCheck(rocblas_gemm_ex(handle,
@@ -563,13 +563,13 @@ std::vector<torch::Tensor> bwd_cuda(
 static_cast<void*>(input_weight_grads.data_ptr()),
 c_type,
 embed_dim,
-static_cast<void*>(input_weight_grads.data_ptr()),
-d_type,
-embed_dim,
+static_cast<void*>(input_weight_grads.data_ptr()),
+d_type,
+embed_dim,
 compute_type,
 algo,
-solution_index,
-flags));
+solution_index,
+flags));
 // Fused Layer Norm Bwd with Residual Add
 HostLayerNormGradient<half,float>(
......
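The diff is truncated here, in the middle of the fused layer-norm backward ("Fused Layer Norm Bwd with Residual Add", HostLayerNormGradient<half,float>). For orientation, the sketch below is a plain CPU reference of a layer-norm backward in which the gradient arriving through the skip connection is simply added into the returned input gradient; it is an assumption about what the fused kernel computes, not a transcription of its interface.

#include <cmath>
#include <cstddef>
#include <vector>

// CPU reference for a layer-norm backward over rows of length H, with the gradient
// from a residual/skip connection folded into dx at the end. Meant only as a
// readable cross-check of what a fused GPU kernel such as HostLayerNormGradient
// computes; the real kernel's interface and fusion details are not shown in this diff.
void layernorm_bwd_with_residual_ref(std::size_t rows, std::size_t H,
                                     const std::vector<float>& x,        // [rows*H] inputs
                                     const std::vector<float>& dy,       // [rows*H] upstream grads
                                     const std::vector<float>& dresid,   // [rows*H] skip-path grads
                                     const std::vector<float>& gamma,    // [H]
                                     float eps,
                                     std::vector<float>& dx,             // [rows*H]
                                     std::vector<float>& dgamma,         // [H]
                                     std::vector<float>& dbeta) {        // [H]
  dx.assign(rows * H, 0.f);
  dgamma.assign(H, 0.f);
  dbeta.assign(H, 0.f);
  for (std::size_t r = 0; r < rows; ++r) {
    const float* xr  = &x[r * H];
    const float* dyr = &dy[r * H];
    // Recompute the row statistics used in the forward pass.
    float mu = 0.f, var = 0.f;
    for (std::size_t j = 0; j < H; ++j) mu += xr[j];
    mu /= H;
    for (std::size_t j = 0; j < H; ++j) var += (xr[j] - mu) * (xr[j] - mu);
    var /= H;
    const float rsigma = 1.f / std::sqrt(var + eps);
    // g = dy * gamma; accumulate its mean and its correlation with xhat.
    float mean_g = 0.f, mean_gx = 0.f;
    for (std::size_t j = 0; j < H; ++j) {
      const float xhat = (xr[j] - mu) * rsigma;
      const float g = dyr[j] * gamma[j];
      mean_g  += g;
      mean_gx += g * xhat;
      dgamma[j] += dyr[j] * xhat;   // parameter grads accumulate over rows
      dbeta[j]  += dyr[j];
    }
    mean_g  /= H;
    mean_gx /= H;
    for (std::size_t j = 0; j < H; ++j) {
      const float xhat = (xr[j] - mu) * rsigma;
      dx[r * H + j] = rsigma * (dyr[j] * gamma[j] - mean_g - xhat * mean_gx)
                      + dresid[r * H + j];   // residual add: fold in the skip-path gradient
    }
  }
}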