Commit 79f3caf8 authored by danyao12's avatar danyao12
Browse files

fix argc order

parent 3724ab55
...@@ -7,9 +7,9 @@ add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_pe ...@@ -7,9 +7,9 @@ add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_pe
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_multihead_attention_forward grouped_multihead_attention_forward.cpp) add_example_executable(example_grouped_multihead_attention_forward grouped_multihead_attention_forward.cpp)
add_example_executable(example_batched_multihead_attention_forward batched_multihead_attention_forward.cpp) add_example_executable(example_batched_multihead_attention_forward batched_multihead_attention_forward.cpp)
add_example_executable(example_grouped_multihead_attention_backward grouped_multihead_attention_backward.cpp)
add_example_executable(example_batched_multihead_attention_backward batched_multihead_attention_backward.cpp) add_example_executable(example_batched_multihead_attention_backward batched_multihead_attention_backward.cpp)
add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp) add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp)
add_example_executable(example_grouped_multihead_attention_backward_fp16 grouped_multihead_attention_backward_fp16.cpp)
add_custom_target(example_gemm_scale_softmax_gemm) add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16) add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
......
...@@ -24,7 +24,7 @@ Kernel outputs: ...@@ -24,7 +24,7 @@ Kernel outputs:
*/ */
#define PRINT_HOST 0 #define PRINT_HOST 0
#define USING_MASK 1 #define USING_MASK 0
#define RANGE_HDKO 2 // 0~2 #define RANGE_HDKO 2 // 0~2
#include <iostream> #include <iostream>
...@@ -488,8 +488,8 @@ int run(int argc, char* argv[]) ...@@ -488,8 +488,8 @@ int run(int argc, char* argv[])
ck::index_t K = 80; // 64<K/O<=128 ck::index_t K = 80; // 64<K/O<=128
#endif #endif
ck::index_t O = K; ck::index_t O = K;
ck::index_t G0 = 3; ck::index_t G0 = 54;
ck::index_t G1 = 2; ck::index_t G1 = 16;
float alpha = 1.f / std::sqrt(K); float alpha = 1.f / std::sqrt(K);
...@@ -497,9 +497,6 @@ int run(int argc, char* argv[]) ...@@ -497,9 +497,6 @@ int run(int argc, char* argv[])
bool output_permute = false; bool output_permute = false;
float p_drop = 0.2; float p_drop = 0.2;
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
const unsigned long long seed = 1; const unsigned long long seed = 1;
const unsigned long long offset = 0; const unsigned long long offset = 0;
...@@ -513,7 +510,7 @@ int run(int argc, char* argv[]) ...@@ -513,7 +510,7 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 13) else if(argc == 14)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
...@@ -527,11 +524,10 @@ int run(int argc, char* argv[]) ...@@ -527,11 +524,10 @@ int run(int argc, char* argv[])
G1 = std::stoi(argv[9]); G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]); alpha = std::stof(argv[10]);
p_drop = std::stof(argv[11]);
input_permute = std::stoi(argv[11]); input_permute = std::stoi(argv[12]);
output_permute = std::stoi(argv[12]); output_permute = std::stoi(argv[13]);
p_drop = std::stoi(argv[13]);
} }
else else
{ {
...@@ -544,6 +540,10 @@ int run(int argc, char* argv[]) ...@@ -544,6 +540,10 @@ int run(int argc, char* argv[])
exit(0); exit(0);
} }
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
std::cout << "do_verification: " << do_verification << std::endl; std::cout << "do_verification: " << do_verification << std::endl;
std::cout << "init_method: " << init_method << std::endl; std::cout << "init_method: " << init_method << std::endl;
std::cout << "time_kernel: " << time_kernel << std::endl; std::cout << "time_kernel: " << time_kernel << std::endl;
......
...@@ -718,9 +718,6 @@ int run(int argc, char* argv[]) ...@@ -718,9 +718,6 @@ int run(int argc, char* argv[])
bool output_permute = true; bool output_permute = true;
float p_drop = 0.3; float p_drop = 0.3;
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
const unsigned long long seed = 1; const unsigned long long seed = 1;
const unsigned long long offset = 0; const unsigned long long offset = 0;
...@@ -734,7 +731,7 @@ int run(int argc, char* argv[]) ...@@ -734,7 +731,7 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 13) else if(argc == 14)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
...@@ -748,11 +745,10 @@ int run(int argc, char* argv[]) ...@@ -748,11 +745,10 @@ int run(int argc, char* argv[])
G1 = std::stoi(argv[9]); G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]); alpha = std::stof(argv[10]);
p_drop = std::stof(argv[11]);
input_permute = std::stoi(argv[11]); input_permute = std::stoi(argv[12]);
output_permute = std::stoi(argv[12]); output_permute = std::stoi(argv[13]);
p_drop = std::stoi(argv[13]);
} }
else else
{ {
...@@ -765,6 +761,10 @@ int run(int argc, char* argv[]) ...@@ -765,6 +761,10 @@ int run(int argc, char* argv[])
exit(0); exit(0);
} }
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
std::cout << "do_verification: " << do_verification << std::endl; std::cout << "do_verification: " << do_verification << std::endl;
std::cout << "init_method: " << init_method << std::endl; std::cout << "init_method: " << init_method << std::endl;
std::cout << "time_kernel: " << time_kernel << std::endl; std::cout << "time_kernel: " << time_kernel << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment