"vscode:/vscode.git/clone" did not exist on "ad2953f6608514f2bc3d9996dcd354eb0d9cb7a7"
Commit 513abed6 authored by ltqin's avatar ltqin
Browse files

firt version for 2 kernel

parent 71e2a917
......@@ -12,6 +12,7 @@ add_example_executable(example_grouped_multihead_attention_backward_v2 grouped_m
add_example_executable(example_batched_multihead_attention_backward batched_multihead_attention_backward.cpp)
add_example_executable(example_batched_multihead_attention_backward_v2 batched_multihead_attention_backward_v2.cpp)
add_example_executable(example_batched_multihead_attention_backward_v3 batched_multihead_attention_backward_v3.cpp)
add_example_executable(example_batched_multihead_attention_backward_v4 batched_multihead_attention_backward_v4.cpp)
add_example_executable(example_grouped_multihead_attention_train grouped_multihead_attention_train.cpp)
add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp)
......
......@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8.
#define DIM 64 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
......@@ -259,7 +259,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
int run(int argc, char* argv[])
{
bool do_verification = true;
bool do_verification = false;
int init_method = 2; // method 1 will have slightly higher error; TODO: to investigate
bool time_kernel = true;
......@@ -271,8 +271,8 @@ int run(int argc, char* argv[])
ck::index_t N = 512;
ck::index_t K = DIM;
ck::index_t O = DIM;
ck::index_t G0 = 4;
ck::index_t G1 = 6;
ck::index_t G0 = 12;
ck::index_t G1 = 16;
bool input_permute = false;
bool output_permute = false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment