"...composable_kernel_rocm.git" did not exist on "9868fd0245eea8f905cd756ab5c6c145c447c596"
Commit f906b23d authored by Anthony Chang's avatar Anthony Chang
Browse files

account for extra flops/bytes from normalization

parent 9f6dbb55
...@@ -50,8 +50,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmLayerNorm_Xdl ...@@ -50,8 +50,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmLayerNorm_Xdl
< Row, Col, Row, F16, F16, F16, AccDataType, AccDataType, AccDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < Row, Col, Row, F16, F16, F16, AccDataType, AccDataType, AccDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on // clang-format on
using ReferenceInstance = ck::tensor_operation::host:: using ReferenceInstance = ck::tensor_operation::host::ReferenceGemmLayernorm<ADataType,
ReferenceGemmLayernorm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>; BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
...@@ -199,9 +204,11 @@ int main(int argc, char* argv[]) ...@@ -199,9 +204,11 @@ int main(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; // extra 5MN flops due to: bias + gamma + beta + norm_sub + norm_div,
std::size_t num_btype = // excluding reduction steps
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; std::size_t flop = std::size_t(2) * M * N * K + std::size_t(5) * M * N;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(CDataType) * 3 * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -218,8 +225,15 @@ int main(int argc, char* argv[]) ...@@ -218,8 +225,15 @@ int main(int argc, char* argv[])
auto ref_gemm = ReferenceInstance{}; auto ref_gemm = ReferenceInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument( auto ref_argument = ref_gemm.MakeArgument(a_m_k,
a_m_k, b_k_n, c0_n_bias, c0_n_gamma, c0_n_beta, c_m_n_host_result, a_element_op, b_element_op, c_element_op); b_k_n,
c0_n_bias,
c0_n_gamma,
c0_n_beta,
c_m_n_host_result,
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment