Commit 79f3caf8 authored by danyao12's avatar danyao12
Browse files

fix argc order

parent 3724ab55
......@@ -7,9 +7,9 @@ add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_pe
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_multihead_attention_forward grouped_multihead_attention_forward.cpp)
add_example_executable(example_batched_multihead_attention_forward batched_multihead_attention_forward.cpp)
add_example_executable(example_grouped_multihead_attention_backward grouped_multihead_attention_backward.cpp)
add_example_executable(example_batched_multihead_attention_backward batched_multihead_attention_backward.cpp)
add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp)
add_example_executable(example_grouped_multihead_attention_backward_fp16 grouped_multihead_attention_backward_fp16.cpp)
add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
......
......@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define PRINT_HOST 0
#define USING_MASK 1
#define USING_MASK 0
#define RANGE_HDKO 2 // 0~2
#include <iostream>
......@@ -488,8 +488,8 @@ int run(int argc, char* argv[])
ck::index_t K = 80; // 64<K/O<=128
#endif
ck::index_t O = K;
ck::index_t G0 = 3;
ck::index_t G1 = 2;
ck::index_t G0 = 54;
ck::index_t G1 = 16;
float alpha = 1.f / std::sqrt(K);
......@@ -497,9 +497,6 @@ int run(int argc, char* argv[])
bool output_permute = false;
float p_drop = 0.2;
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
......@@ -513,7 +510,7 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
else if(argc == 14)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
......@@ -527,11 +524,10 @@ int run(int argc, char* argv[])
G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]);
p_drop = std::stof(argv[11]);
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
p_drop = std::stoi(argv[13]);
input_permute = std::stoi(argv[12]);
output_permute = std::stoi(argv[13]);
}
else
{
......@@ -544,6 +540,10 @@ int run(int argc, char* argv[])
exit(0);
}
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
std::cout << "do_verification: " << do_verification << std::endl;
std::cout << "init_method: " << init_method << std::endl;
std::cout << "time_kernel: " << time_kernel << std::endl;
......
......@@ -718,9 +718,6 @@ int run(int argc, char* argv[])
bool output_permute = true;
float p_drop = 0.3;
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
......@@ -734,7 +731,7 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
else if(argc == 14)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
......@@ -748,11 +745,10 @@ int run(int argc, char* argv[])
G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]);
p_drop = std::stof(argv[11]);
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
p_drop = std::stoi(argv[13]);
input_permute = std::stoi(argv[12]);
output_permute = std::stoi(argv[13]);
}
else
{
......@@ -765,6 +761,10 @@ int run(int argc, char* argv[])
exit(0);
}
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
std::cout << "do_verification: " << do_verification << std::endl;
std::cout << "init_method: " << init_method << std::endl;
std::cout << "time_kernel: " << time_kernel << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment