Commit 7b13018e authored by letaoqin's avatar letaoqin
Browse files

change file run_grouped_multihead_attention_bias_forward.inc name

parent c703632f
add_example_executable(example_batched_flash_attention_forward batched_gemm_multihead_attention_forward.cpp)
add_example_executable(example_batched_flash_attention_bias_forward batched_gemm_multihead_attention_bias_forward.cpp)
add_example_executable(example_batched_multihead_attention_forward batched_gemm_multihead_attention_forward.cpp)
add_example_executable(example_batched_multihead_attention_bias_forward batched_gemm_multihead_attention_bias_forward.cpp)
add_example_executable(example_grouped_multihead_attention_bias_forward grouped_mutihead_attention_bias_forward.cpp)
add_example_executable(example_batched_multihead_attention_bias_forward_v2 batched_multihead_attention_bias_forward_v2.cpp)
add_example_executable(example_grouped_multihead_attention_bias_forward_v2 grouped_multihead_attention_bias_forward_v2.cpp)
......
......@@ -327,6 +327,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
using ReferenceDropoutInstance =
ck::tensor_operation::host::ReferenceDropout<ZDataType, ADataType, ADataType>;
#include "run_grouped_multihead_attention_bias_forward.inc"
#include "run_grouped_multihead_attention_bias_forward_v2.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
......@@ -416,8 +416,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
}
static auto
......@@ -425,8 +425,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment