Commit 513abed6 authored by ltqin

first version for 2 kernels

parent 71e2a917
@@ -12,6 +12,7 @@ add_example_executable(example_grouped_multihead_attention_backward_v2 grouped_multihead_attention_backward_v2.cpp)
 add_example_executable(example_batched_multihead_attention_backward batched_multihead_attention_backward.cpp)
 add_example_executable(example_batched_multihead_attention_backward_v2 batched_multihead_attention_backward_v2.cpp)
 add_example_executable(example_batched_multihead_attention_backward_v3 batched_multihead_attention_backward_v3.cpp)
+add_example_executable(example_batched_multihead_attention_backward_v4 batched_multihead_attention_backward_v4.cpp)
 add_example_executable(example_grouped_multihead_attention_train grouped_multihead_attention_train.cpp)
 add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp)
......
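For orientation, a hedged sketch of the entry-point skeleton these example .cpp files share, so it is clear what the new example_batched_multihead_attention_backward_v4 target builds. The run() signature matches the hunk further down; the stub body and printout are placeholders, not the real v4 example.

#include <iostream>

// run() parses problem sizes/flags and launches the kernel in the real
// examples; stubbed here so the skeleton is self-contained.
int run(int argc, char* argv[])
{
    std::cout << argv[0] << ": " << (argc - 1) << " argument(s)\n";
    return 0;
}

int main(int argc, char* argv[]) { return run(argc, argv); }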
@@ -25,7 +25,7 @@ Kernel outputs:
 #define PRINT_HOST 0
 #define USING_MASK 0
-#define DIM 128 // DIM should be a multiple of 8.
+#define DIM 64 // DIM should be a multiple of 8.
 #include <iostream>
 #include <numeric>
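The comment on the DIM define documents a multiple-of-8 requirement. A minimal sketch of enforcing it at compile time, assuming (an assumption, not stated in the diff) that the 8 comes from the kernels' 8-element vectorized accesses for fp16 data:

// Hedged sketch: fail the build instead of miscompiling when the
// documented constraint is violated. 8 is assumed to match the
// kernels' vector access width.
#define DIM 64
static_assert(DIM % 8 == 0, "DIM should be a multiple of 8");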
@@ -259,7 +259,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 int run(int argc, char* argv[])
 {
-    bool do_verification = true;
+    bool do_verification = false;
     int init_method = 2; // method 1 will have slightly higher error; TODO: to investigate
     bool time_kernel = true;
@@ -271,8 +271,8 @@ int run(int argc, char* argv[])
     ck::index_t N = 512;
     ck::index_t K = DIM;
     ck::index_t O = DIM;
-    ck::index_t G0 = 4;
-    ck::index_t G1 = 6;
+    ck::index_t G0 = 12;
+    ck::index_t G1 = 16;
     bool input_permute = false;
     bool output_permute = false;
......
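The new defaults (G0 = 12, G1 = 16) grow the batch from 4 * 6 = 24 to 192 attention heads. A hedged sketch of the resulting problem size, treating attention as its two GEMMs; M is not visible in this hunk, so 512 is an assumed placeholder:

#include <cstdint>
#include <iostream>

int main()
{
    const std::int64_t M = 512;                 // assumed; not shown in the hunk
    const std::int64_t N = 512, K = 64, O = 64; // DIM = 64
    const std::int64_t G0 = 12, G1 = 16;        // new defaults
    const std::int64_t batch = G0 * G1;         // 192 (was 24)
    // S = Q*K^T costs 2*M*N*K flops and O = P*V costs 2*M*N*O, per batch.
    const std::int64_t flop = 2 * batch * M * N * (K + O);
    std::cout << "batch=" << batch << ", flop=" << flop << '\n';
}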
@@ -1448,7 +1448,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             1,
             false>{
             lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
             make_multi_index(block_work_idx_m,        // mblock
                              acc0_thread_origin[I0],  // mrepeat
                              acc0_thread_origin[I2],  // mwave
                              acc0_thread_origin[I4])}; // mperxdl
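For readers not fluent in these descriptors, a sketch of how a flat M coordinate splits into the (mblock, mrepeat, mwave, mperxdl) tuple that lse_grid_desc_mblock_mrepeat_mwave_mperxdl expects. The tile parameters below are hypothetical, chosen only to make the arithmetic concrete; they are not taken from this diff.

#include <array>
#include <cstdint>
#include <iostream>

// Hypothetical tile shape (illustration only).
constexpr std::int64_t MPerXdl   = 32;
constexpr std::int64_t MWave     = 2;
constexpr std::int64_t MRepeat   = 2;
constexpr std::int64_t MPerBlock = MRepeat * MWave * MPerXdl; // 128 rows/block

// Innermost dimension varies fastest:
// m = ((mblock*MRepeat + mrepeat)*MWave + mwave)*MPerXdl + mperxdl
std::array<std::int64_t, 4> decompose_m(std::int64_t m)
{
    return {m / MPerBlock,                       // mblock
            (m % MPerBlock) / (MWave * MPerXdl), // mrepeat
            (m / MPerXdl) % MWave,               // mwave
            m % MPerXdl};                        // mperxdl
}

int main()
{
    for(std::int64_t m : {0, 63, 200})
    {
        const auto c = decompose_m(m);
        std::cout << m << " -> (" << c[0] << ", " << c[1] << ", " << c[2]
                  << ", " << c[3] << ")\n";
    }
}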
@@ -1511,14 +1511,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             1,    // DstScalarStrideInVector
             true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                   make_multi_index(block_work_idx_m, // MBlockId
                                    0,                // NBlockId
                                    0,                // mrepeat
                                    0,                // nrepeat
                                    wave_id[I0],      // MWaveId
                                    wave_id[I1],      // NWaveId
                                    wave_m_n_id[I1],  // MPerXdl
                                    0,                // group
                                    wave_m_n_id[I0],  // NInputIndex
                                    0),
                   tensor_operation::element_wise::PassThrough{}};
......
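The ten origin components above map block, wave, and lane IDs onto the z descriptor's dimensions. As a readability aid, a sketch gathering the mapping from the hunk's inline comments into one plain struct; any field semantics beyond those comments are assumptions.

// Hedged restatement of the make_multi_index(...) arguments above.
struct ZThreadCopyOrigin
{
    int m_block;       // MBlockId    <- block_work_idx_m
    int n_block;       // NBlockId    <- 0
    int m_repeat;      // mrepeat     <- 0
    int n_repeat;      // nrepeat     <- 0
    int m_wave;        // MWaveId     <- wave_id[I0]
    int n_wave;        // NWaveId     <- wave_id[I1]
    int m_per_xdl;     // MPerXdl     <- wave_m_n_id[I1]
    int group;         // group       <- 0
    int n_input_index; // NInputIndex <- wave_m_n_id[I0]
    int trailing;      // final 0 in the index
};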
@@ -917,7 +917,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             {
                 block_sync_lds();
             }
             do
             {
                 auto n_block_data_idx_on_grid =
......
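The forward-kernel hunk shows an LDS barrier followed by a do loop over N blocks. A host-side sketch of that control flow, runnable as plain C++; block_sync_lds is stubbed, and NPerBlock/num_n_block are hypothetical names standing in for whatever the real pipeline uses.

#include <cstdint>
#include <iostream>

using index_t = std::int32_t;
void block_sync_lds() {} // stand-in for the device LDS barrier

int main()
{
    constexpr index_t NPerBlock   = 128;           // hypothetical tile width
    constexpr index_t N           = 512;           // matches the example default
    constexpr index_t num_n_block = N / NPerBlock; // 4 tiles

    index_t n_block = 0;
    do // body runs at least once, as in the hunk above
    {
        block_sync_lds(); // in-kernel: publish prior LDS writes
        const index_t n_block_data_idx_on_grid = n_block * NPerBlock;
        std::cout << "N tile starts at column " << n_block_data_idx_on_grid << '\n';
        ++n_block;
    } while(n_block < num_n_block);
}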