// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <thread>

/// Launch the batched GEMM through gemm_calc, then print the measured time together with the
/// derived TFlops and GB/s figures.
///
/// ADataType / BDataType / CDataType, batched_gemm_basic_args and gemm_calc are not declared in
/// this file; they are expected to come from the translation unit that includes it.
///
/// @param a_m_k_dev_buf       device buffer holding the A operand(s)
/// @param b_k_n_dev_buf       device buffer holding the B operand(s)
/// @param c_m_n_dev_buf       device buffer receiving C
/// @param M, N, K             GEMM problem size of a single batch
/// @param stride_A/B/C        row/column strides (in elements) handed to the kernel
/// @param kbatch              forwarded unchanged to batched_gemm_basic_args::kbatch
/// @param batch_stride_A/B/C  element distance between consecutive batches of A, B and C
/// @param batch_count         number of GEMM problems executed by the kernel
/// @param n_warmup, n_repeat  warm-up and timed iteration counts for the stream_config
/// @return the average time reported by gemm_calc (printed as milliseconds)
template <typename ALayout, typename BLayout, typename CLayout>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                  ck_tile::DeviceMem& b_k_n_dev_buf,
                  ck_tile::DeviceMem& c_m_n_dev_buf,
                  ck_tile::index_t M,
                  ck_tile::index_t N,
                  ck_tile::index_t K,
                  ck_tile::index_t stride_A,
                  ck_tile::index_t stride_B,
                  ck_tile::index_t stride_C,
                  ck_tile::index_t kbatch,
                  ck_tile::index_t batch_stride_A,
                  ck_tile::index_t batch_stride_B,
                  ck_tile::index_t batch_stride_C,
                  ck_tile::index_t batch_count,
                  int n_warmup,
                  int n_repeat)
{
    batched_gemm_basic_args args;
    args.p_a            = a_m_k_dev_buf.GetDeviceBuffer();
    args.p_b            = b_k_n_dev_buf.GetDeviceBuffer();
    args.p_c            = c_m_n_dev_buf.GetDeviceBuffer();
    args.kbatch         = kbatch;
    args.M              = M;
    args.N              = N;
    args.K              = K;
    args.stride_A       = stride_A;
    args.stride_B       = stride_B;
    args.stride_C       = stride_C;
    args.batch_stride_A = batch_stride_A;
    args.batch_stride_B = batch_stride_B;
    args.batch_stride_C = batch_stride_C;
    args.batch_count    = batch_count;

    float ave_time = gemm_calc<ALayout, BLayout, CLayout>(
        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});

    std::string op_name{"Gemm{MemBoundPipeline}"};

    // The kernel runs batch_count independent M x N x K GEMMs, so the arithmetic work scales
    // with the batch count. The byte count likewise assumes each batch reads/writes its own
    // A, B and C (i.e. non-zero batch strides); it is a throughput estimate, not a measurement.
    const std::size_t batches = static_cast<std::size_t>(batch_count);
    std::size_t flop          = std::size_t(2) * batches * M * N * K;
    std::size_t num_byte =
        batches *
        (sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N);
    // flop / 1e9 / ms == flop / 1e12 / s (TFLOP/s); bytes / 1e6 / ms == bytes / 1e9 / s (GB/s).
    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_byte / 1.E6 / ave_time;

    std::cout << "Run " << op_name << " kernel with M =" << M << " N =" << N << " K =" << K
              << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
              << " BatchStrideA =" << batch_stride_A << " BatchStrideB =" << batch_stride_B
              << " BatchStrideC =" << batch_stride_C << " BatchCount =" << batch_count << " : "
              << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << std::endl;

    return ave_time;
}

int run_batched_gemm_example(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    ck_tile::index_t M = arg_parser.get_int("m");
    ck_tile::index_t N = arg_parser.get_int("n");
    ck_tile::index_t K = arg_parser.get_int("k");

    ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
    ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
    ck_tile::index_t stride_C = arg_parser.get_int("stride_c");

    ck_tile::index_t batch_size = arg_parser.get_int("b");

    ck_tile::index_t batch_stride_A = arg_parser.get_int("batch_stride_a");
    ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
    ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
    ck_tile::index_t batch_count    = arg_parser.get_int("batch_count");

    std::cout << "Received args: " << std::endl;
    std::cout << "batch_stride_A: " << batch_stride_A << '\n'
              << "batch_stride_B: " << batch_stride_B << '\n'
              << "batch_stride_C: " << batch_stride_C << '\n'
              << "batch_count: " << batch_count << std::endl;

    int n_warmup = arg_parser.get_int("warmup");
    int n_repeat = arg_parser.get_int("repeat");

    using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
    using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
    using CLayout = ck_tile::tensor_layout::gemm::RowMajor;

    using namespace ck_tile::literals;

    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
            {
                return ck_tile::HostTensorDescriptor({static_cast<size_t>(16), row, col},
                                                     {row * col, stride, 1_uz});
            }
            else
            {
                return ck_tile::HostTensorDescriptor({static_cast<size_t>(16), row, col},
                                                     {row * col, 1_uz, stride});
            }
        };

    auto f_get_default_stride = [](std::size_t row,
                                   std::size_t col,
                                   std::size_t stride,
                                   auto layout) {
        if(stride == 0)
        {
            // give a chance if stride is zero, return a default packed stride
            if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
            {
                return col;
            }
            else
            {
                return row;
            }
        }
        else
            return stride;
    };

    stride_A = f_get_default_stride(M, K, stride_A, ALayout{});
    stride_B = f_get_default_stride(K, N, stride_B, BLayout{});
    stride_C = f_get_default_stride(M, N, stride_C, CLayout{});

    ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
    ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
    ck_tile::HostTensor<CDataType> c_m_n_dev_result(
        f_host_tensor_descriptor(M, N, stride_C, CLayout{}));

    // TODO: add different init types

    ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
    ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);

    ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
    ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
    ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());

    a_m_k_dev_buf.ToDevice(a_m_k.data());
    b_k_n_dev_buf.ToDevice(b_k_n.data());
    c_m_n_dev_buf.SetZero();
    c_m_n_dev_result.SetZero();

    invoke_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf,
                                           b_k_n_dev_buf,
                                           c_m_n_dev_buf,
                                           M,
                                           N,
                                           K,
                                           stride_A,
                                           stride_B,
                                           stride_C,
                                           batch_size,
                                           batch_stride_A,
                                           batch_stride_B,
                                           batch_stride_C,
                                           batch_count,
                                           n_warmup,
                                           n_repeat);

    c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
    bool pass = true;

    if(arg_parser.get_int("v") == 1)
    {
        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
        c_m_n_host_ref.SetZero();

        ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
            a_m_k, b_k_n, c_m_n_host_ref);

        pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);

        std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
    }
    else if(arg_parser.get_int("v") == 2)
    {
        ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
        ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
        c_m_n_gpu_ref.SetZero();
        c_m_n_gpu_buf_ref.SetZero();

        ck_tile::reference_gemm_gpu<ADataType,
                                    BDataType,
                                    AccDataType,
                                    CDataType,
                                    ALayout,
                                    BLayout,
                                    CLayout>(a_m_k_dev_buf,
                                             b_k_n_dev_buf,
                                             c_m_n_gpu_buf_ref,
                                             M,
                                             N,
                                             K,
                                             stride_A,
                                             stride_B,
                                             stride_C,
                                             batch_stride_A,
                                             batch_stride_B,
                                             batch_stride_C,
                                             batch_count);

        c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
        pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);

        std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
    }

    return pass;
}
