"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "4facbe998c9c9007a0b34ca524da3b373a038c60"
Commit 61ee7e8a authored by Chao Liu's avatar Chao Liu
Browse files

test freq

parent 997b469c
...@@ -160,6 +160,8 @@ int main(int argc, char* argv[]) ...@@ -160,6 +160,8 @@ int main(int argc, char* argv[])
ck::index_t BatchStrideC = -1; ck::index_t BatchStrideC = -1;
float alpha = 1; float alpha = 1;
int nrepeat = 1;
if(argc == 1) if(argc == 1)
{ {
// use default case // use default case
...@@ -183,7 +185,7 @@ int main(int argc, char* argv[]) ...@@ -183,7 +185,7 @@ int main(int argc, char* argv[])
BatchCount = std::stoi(argv[8]); BatchCount = std::stoi(argv[8]);
} }
else if(argc == 18) else if(argc == 19)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
...@@ -207,6 +209,8 @@ int main(int argc, char* argv[]) ...@@ -207,6 +209,8 @@ int main(int argc, char* argv[])
BatchStrideC = std::stoi(argv[16]); BatchStrideC = std::stoi(argv[16]);
alpha = std::stof(argv[17]); alpha = std::stof(argv[17]);
nrepeat = std::stoi(argv[18]);
} }
else else
{ {
...@@ -348,19 +352,22 @@ int main(int argc, char* argv[]) ...@@ -348,19 +352,22 @@ int main(int argc, char* argv[])
return 0; return 0;
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); for(int i = 0; i < nrepeat; ++i)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
BatchCount; BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< gemm.GetTypeString() << std::endl; << " GB/s, " << gemm.GetTypeString() << std::endl;
}
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data()); c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment