Commit 513abed6 authored by ltqin's avatar ltqin
Browse files

firt version for 2 kernel

parent 71e2a917
...@@ -12,6 +12,7 @@ add_example_executable(example_grouped_multihead_attention_backward_v2 grouped_m ...@@ -12,6 +12,7 @@ add_example_executable(example_grouped_multihead_attention_backward_v2 grouped_m
add_example_executable(example_batched_multihead_attention_backward batched_multihead_attention_backward.cpp) add_example_executable(example_batched_multihead_attention_backward batched_multihead_attention_backward.cpp)
add_example_executable(example_batched_multihead_attention_backward_v2 batched_multihead_attention_backward_v2.cpp) add_example_executable(example_batched_multihead_attention_backward_v2 batched_multihead_attention_backward_v2.cpp)
add_example_executable(example_batched_multihead_attention_backward_v3 batched_multihead_attention_backward_v3.cpp) add_example_executable(example_batched_multihead_attention_backward_v3 batched_multihead_attention_backward_v3.cpp)
add_example_executable(example_batched_multihead_attention_backward_v4 batched_multihead_attention_backward_v4.cpp)
add_example_executable(example_grouped_multihead_attention_train grouped_multihead_attention_train.cpp) add_example_executable(example_grouped_multihead_attention_train grouped_multihead_attention_train.cpp)
add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp) add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp)
......
...@@ -25,7 +25,7 @@ Kernel outputs: ...@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0 #define PRINT_HOST 0
#define USING_MASK 0 #define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8. #define DIM 64 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -259,7 +259,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k, ...@@ -259,7 +259,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
int run(int argc, char* argv[]) int run(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = false;
int init_method = 2; // method 1 will have slightly higher error; TODO: to investigate int init_method = 2; // method 1 will have slightly higher error; TODO: to investigate
bool time_kernel = true; bool time_kernel = true;
...@@ -271,8 +271,8 @@ int run(int argc, char* argv[]) ...@@ -271,8 +271,8 @@ int run(int argc, char* argv[])
ck::index_t N = 512; ck::index_t N = 512;
ck::index_t K = DIM; ck::index_t K = DIM;
ck::index_t O = DIM; ck::index_t O = DIM;
ck::index_t G0 = 4; ck::index_t G0 = 12;
ck::index_t G1 = 6; ck::index_t G1 = 16;
bool input_permute = false; bool input_permute = false;
bool output_permute = false; bool output_permute = false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment