"...composable_kernel_rocm.git" did not exist on "85fc91c3218c1d85169ed1fe95eef7b07942e648"
Commit ad93b411 authored by Anthony Chang
Browse files

calculate correct flops and bytes

parent ec9c2b5e
...@@ -492,11 +492,17 @@ int run(int argc, char* argv[]) ...@@ -492,11 +492,17 @@ int run(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
// TODO ANT: add dQ/dK/dV flops & bytes // 5 GEMM ops in total:
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; // S_MNK / dP_MNO Gemm (Gemm0 rcr)
// dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
// 3x MNK + 2x MNO
std::size_t flop = (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
// Q/K/V/Y, dQ/dK/dV/dY, LSE
std::size_t num_btype = (sizeof(DataType) * M * K + sizeof(DataType) * K * N + std::size_t num_btype = (sizeof(DataType) * M * K + sizeof(DataType) * K * N +
sizeof(DataType) * N * O + sizeof(DataType) * M * O) * sizeof(DataType) * N * O + sizeof(DataType) * M * O) *
BatchCount; size_t(2) * BatchCount +
sizeof(LSEDataType) * M * BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment