Commit 56cc306d authored by coderfeli
Browse files

fix perf calc

parent 7572a691
......@@ -132,7 +132,7 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32;
......@@ -194,7 +194,7 @@ int main(int argc, char* argv[])
ck::index_t N = 6144;
ck::index_t K = 8192;
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 9;
ck::index_t sorted_tile_num = 8;
ck::index_t valid_tile_num = 8;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
......@@ -207,13 +207,14 @@ int main(int argc, char* argv[])
{
// use default case
}
else if(argc == 6)
else if(argc == 7)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else
{
......@@ -221,10 +222,15 @@ int main(int argc, char* argv[])
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf(
"arg4 to 5: N, K\n");
"arg4 to 5: N, K, tokens\n");
exit(0);
}
if (tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
......@@ -235,7 +241,7 @@ int main(int argc, char* argv[])
// const ck::index_t experts = 8;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({experts}, {1}));
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
max_token_id.mData[0] = valid_size;
......@@ -246,7 +252,7 @@ int main(int argc, char* argv[])
int tokenid = 0;
// sorted_token_ids.mData[0] = 0;
for (int i = 0; i < sorted_size; i++) {
int tile_off = i % valid_size;
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
......@@ -278,9 +284,9 @@ int main(int argc, char* argv[])
case 0: break;
case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{1, 3});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{1, 3});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
......@@ -358,7 +364,7 @@ int main(int argc, char* argv[])
if (time_kernel) {
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * valid_tile_num * N * K;
std::size_t flop = std::size_t(2) * tokens * topk * N * K;
std::size_t num_btype =
sizeof(A0DataType) * valid_tile_num * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * valid_tile_num * N;
......
......@@ -200,13 +200,14 @@ int main(int argc, char* argv[])
{
// use default case
}
else if(argc == 6)
else if(argc == 7)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else
{
......@@ -214,7 +215,7 @@ int main(int argc, char* argv[])
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf(
"arg4 to 5: N, K\n");
"arg4 to 6: N, K, tokens\n");
exit(0);
}
......@@ -244,7 +245,7 @@ int main(int argc, char* argv[])
int tokenid = 0;
// sorted_token_ids.mData[0] = 0;
for (int i = 0; i < sorted_size; i++) {
int tile_off = i % valid_size;
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment