Commit 61416180 authored by hubertlu-tw

Hipify self_multihead_attn_bias

Fix some spacing
parent 8bdbb502
@@ -3,7 +3,7 @@
 namespace multihead_attn {
 namespace self_bias {
-namespace cublas_gemmex {
+namespace rocblas_gemm_ex {
 std::vector<torch::Tensor> fwd_cuda(
     bool use_time_mask,
@@ -128,12 +128,12 @@ std::vector<torch::Tensor> bwd(
   );
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemm_ex
 } // end namespace self
 } // end namespace multihead_attn
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward", &multihead_attn::self_bias::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
-  m.def("backward", &multihead_attn::self_bias::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
+  m.def("forward", &multihead_attn::self_bias::rocblas_gemm_ex::fwd, "Self Multihead Attention with Bias -- Forward.");
+  m.def("backward", &multihead_attn::self_bias::rocblas_gemm_ex::bwd, "Self Multihead Attention with Bias -- Backward.");
 }
@@ -124,13 +124,13 @@ std::vector<torch::Tensor> fwd_cuda(
                      q_lin_results_ptr,
                      c_type,
                      output_lin_dim,
                      q_lin_results_ptr,
                      d_type,
                      output_lin_dim,
                      compute_type,
                      algo,
                      solution_index,
                      flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum( state,
@@ -150,9 +150,9 @@ std::vector<torch::Tensor> fwd_cuda(
                      static_cast<half*>(softmax_results_ptr),
                      k_seq_len,
                      k_seq_len*q_seq_len,
                      static_cast<half*>(softmax_results_ptr),
                      k_seq_len,
                      k_seq_len*q_seq_len,
                      attn_batches);
   // Padded Softmax
@@ -215,9 +215,9 @@ std::vector<torch::Tensor> fwd_cuda(
                      static_cast<half*>(matmul2_results.data_ptr()),
                      head_dim*attn_batches,
                      head_dim,
                      static_cast<half*>(matmul2_results.data_ptr()),
                      head_dim*attn_batches,
                      head_dim,
                      attn_batches);
   // Output Linear
@@ -238,13 +238,13 @@ std::vector<torch::Tensor> fwd_cuda(
                      static_cast<void*>(output_lin_results.data_ptr()),
                      c_type,
                      embed_dim,
                      static_cast<void*>(output_lin_results.data_ptr()),
                      d_type,
                      embed_dim,
                      compute_type,
                      algo,
                      solution_index,
                      flags));
   // End-of-block Dropout-Add
@@ -372,13 +372,13 @@ std::vector<torch::Tensor> bwd_cuda(
                      static_cast<void*>(output_lin_grads.data_ptr()),
                      c_type,
                      embed_dim,
                      static_cast<void*>(output_lin_grads.data_ptr()),
                      d_type,
                      embed_dim,
                      compute_type,
                      algo,
                      solution_index,
                      flags));
   // Output Linear Wgrad
   THCublasCheck(rocblas_gemm_ex(handle,
@@ -398,13 +398,13 @@ std::vector<torch::Tensor> bwd_cuda(
                      static_cast<void*>(output_weight_grads.data_ptr()),
                      c_type,
                      embed_dim,
                      static_cast<void*>(output_weight_grads.data_ptr()),
                      d_type,
                      embed_dim,
                      compute_type,
                      algo,
                      solution_index,
                      flags));
   // MatMul2 Dgrad1
   gemm_switch_fp32accum( state,
@@ -424,9 +424,9 @@ std::vector<torch::Tensor> bwd_cuda(
                      static_cast<half*>(matmul2_grads.data_ptr()),
                      k_seq_len,
                      k_seq_len*q_seq_len,
                      static_cast<half*>(matmul2_grads.data_ptr()),
                      k_seq_len,
                      k_seq_len*q_seq_len,
                      attn_batches);
   // Matmul2 Dgrad2
@@ -447,9 +447,9 @@ std::vector<torch::Tensor> bwd_cuda(
                      v_lin_grads_ptr,
                      lead_dim,
                      batch_stride,
                      v_lin_grads_ptr,
                      lead_dim,
                      batch_stride,
                      attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability
@@ -489,9 +489,9 @@ std::vector<torch::Tensor> bwd_cuda(
                      q_lin_grads_ptr,
                      lead_dim,
                      batch_stride,
                      q_lin_grads_ptr,
                      lead_dim,
                      batch_stride,
                      attn_batches);
   // Matmul1 Dgrad2
@@ -512,7 +512,7 @@ std::vector<torch::Tensor> bwd_cuda(
                      k_lin_grads_ptr,
                      lead_dim,
                      batch_stride,
                      k_lin_grads_ptr,
                      lead_dim,
                      batch_stride,
                      attn_batches);
@@ -536,13 +536,13 @@ std::vector<torch::Tensor> bwd_cuda(
                      static_cast<void*>(input_lin_grads.data_ptr()),
                      c_type,
                      embed_dim,
                      static_cast<void*>(input_lin_grads.data_ptr()),
                      d_type,
                      embed_dim,
                      compute_type,
                      algo,
                      solution_index,
                      flags));
   // Input Linear Wgrad
   THCublasCheck(rocblas_gemm_ex(handle,
@@ -563,13 +563,13 @@ std::vector<torch::Tensor> bwd_cuda(
                      static_cast<void*>(input_weight_grads.data_ptr()),
                      c_type,
                      embed_dim,
                      static_cast<void*>(input_weight_grads.data_ptr()),
                      d_type,
                      embed_dim,
                      compute_type,
                      algo,
                      solution_index,
                      flags));
   // Fused Layer Norm Bwd with Residual Add
   HostLayerNormGradient<half,float>(
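For readers doing a similar hipify pass, the sketch below shows the bare rocblas_gemm_ex call shape that the renamed namespace wraps: fp16 A/B/C/D buffers, fp32 accumulation, and the trailing compute_type / algo / solution_index / flags arguments visible in the hunks above. This is a minimal standalone sketch, not code from APEX; the problem size, the zero-filled inputs, and the CHECK_ROCBLAS macro are assumptions made for illustration.

// Minimal sketch of a mixed-precision rocblas_gemm_ex call (D = alpha*A*B + beta*C),
// matching the argument pattern seen in the diff above. Sizes and inputs are toy values.
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <hip/hip_runtime.h>
#include <rocblas.h>   // on newer ROCm releases the header lives at <rocblas/rocblas.h>

#define CHECK_ROCBLAS(expr)                                                        \
  do {                                                                             \
    rocblas_status s_ = (expr);                                                    \
    if (s_ != rocblas_status_success) {                                            \
      std::fprintf(stderr, "rocBLAS error %d at %s:%d\n",                          \
                   static_cast<int>(s_), __FILE__, __LINE__);                      \
      std::exit(EXIT_FAILURE);                                                     \
    }                                                                              \
  } while (0)

int main() {
  const rocblas_int m = 64, n = 64, k = 64;   // assumed toy problem size
  const float alpha = 1.0f, beta = 0.0f;      // fp32 scalars because compute_type is fp32

  rocblas_handle handle;
  CHECK_ROCBLAS(rocblas_create_handle(&handle));

  // fp16 device buffers, zero-filled just so the call has valid memory to read.
  void *a, *b, *c;
  hipMalloc(&a, sizeof(uint16_t) * m * k);
  hipMalloc(&b, sizeof(uint16_t) * k * n);
  hipMalloc(&c, sizeof(uint16_t) * m * n);
  hipMemset(a, 0, sizeof(uint16_t) * m * k);
  hipMemset(b, 0, sizeof(uint16_t) * k * n);
  hipMemset(c, 0, sizeof(uint16_t) * m * n);

  // Same trailing-argument pattern as the diff: C pointer/type/ld, D pointer/type/ld,
  // then compute_type, algo, solution_index, flags.
  CHECK_ROCBLAS(rocblas_gemm_ex(handle,
                                rocblas_operation_none, rocblas_operation_none,
                                m, n, k,
                                &alpha,
                                a, rocblas_datatype_f16_r, m,
                                b, rocblas_datatype_f16_r, k,
                                &beta,
                                c, rocblas_datatype_f16_r, m,   // C (input)
                                c, rocblas_datatype_f16_r, m,   // D (output, may alias C)
                                rocblas_datatype_f32_r,         // accumulate in fp32
                                rocblas_gemm_algo_standard,
                                0 /*solution_index*/, 0 /*flags*/));

  hipDeviceSynchronize();
  hipFree(a); hipFree(b); hipFree(c);
  CHECK_ROCBLAS(rocblas_destroy_handle(handle));
  return 0;
}

As in the hunks above, the same buffer is passed for both C and D (rocblas_gemm_ex writes its result to D, which may alias C), and the half-precision operands are accumulated in fp32 through the compute_type argument.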