#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk.hpp"
#include "driver_gemm_xdlops_v3r1.hpp"

// Grouped forward-convolution benchmark driver: NHWGC input x GKYXC weight -> NHWGK output,
// computed as an implicit GEMM through driver_gemm_xdlops_v3r1.
//
// What this function does:
//   1. Allocates device buffers sized from each host tensor's element space and uploads the
//      input, the weight and the current contents of the output tensor.
//   2. Builds packed tensor descriptors from the *_lengths arguments and converts them into
//      GEMM grid descriptors with
//      transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk_pad, using GemmK1
//      from the selected tuning configuration.
//   3. Calls the GEMM driver 5 times. After each call it prints the returned average time
//      (reported as ms) and the TFlop/s derived from it.
//   4. Copies the device output back into out_n_ho_wo_g_k.
//
// Template parameters:
//   TInWei - element type of the input and weight tensors.
//   TAcc   - forwarded to the driver (presumably the accumulation type; confirm in driver).
//   TOut   - element type of the output tensor.
//   The remaining template types are deduced from the corresponding arguments.
//
// Parameters:
//   in_n_hi_wi_g_c_lengths  - input tensor lengths (order per the name: N, Hi, Wi, G, C).
//   wei_g_k_y_x_c_lengths   - weight lengths; entries 0, 2, 3, 4 are read as G, Y, X, C.
//   out_n_ho_wo_g_k_lengths - output lengths; entries 0, 1, 2, 4 are read as N, Ho, Wo, K.
//   conv_strides, conv_dilations, in_left_pads, in_right_pads
//                           - convolution geometry, forwarded to the GEMM transform.
//   in_n_hi_wi_g_c          - host input tensor (read only).
//   wei_g_k_y_x_c           - host weight tensor (read only).
//   out_n_ho_wo_g_k         - host output tensor; overwritten with the device result.
//   nrepeat                 - forwarded to the driver (presumably its timing repeat count).
template <typename TInWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk(
    const InLengths& in_n_hi_wi_g_c_lengths,
    const WeiLengths& wei_g_k_y_x_c_lengths,
    const OutLengths& out_n_ho_wo_g_k_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TInWei>& in_n_hi_wi_g_c,
    const Tensor<TInWei>& wei_g_k_y_x_c,
    Tensor<TOut>& out_n_ho_wo_g_k,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};

    DeviceMem in_n_hi_wi_g_c_device_buf(sizeof(TInWei) * in_n_hi_wi_g_c.mDesc.GetElementSpace());
    DeviceMem wei_g_k_y_x_c_device_buf(sizeof(TInWei) * wei_g_k_y_x_c.mDesc.GetElementSpace());
    DeviceMem out_n_ho_wo_g_k_device_buf(sizeof(TOut) * out_n_ho_wo_g_k.mDesc.GetElementSpace());

    in_n_hi_wi_g_c_device_buf.ToDevice(in_n_hi_wi_g_c.mData.data());
    wei_g_k_y_x_c_device_buf.ToDevice(wei_g_k_y_x_c.mData.data());
    out_n_ho_wo_g_k_device_buf.ToDevice(out_n_ho_wo_g_k.mData.data());

    const auto in_n_hi_wi_g_c_desc  = make_naive_tensor_descriptor_packed(in_n_hi_wi_g_c_lengths);
    const auto wei_g_k_y_x_c_desc   = make_naive_tensor_descriptor_packed(wei_g_k_y_x_c_lengths);
    const auto out_n_ho_wo_g_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_g_k_lengths);

    // Tuning configurations. Exactly one branch is compiled: the first whose condition is 1.
    // Every branch defines the same set of names. The block-transfer slice/cluster lengths are
    // 4-D [G, K0, M|N, K1] with a leading G extent of 1, matching the 4-D access orders passed
    // to the driver below, so any branch can be enabled without touching the driver call.
#if 0
    // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1      = 4;

    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmG_GemmK0_GemmM_GemmK1   = Sequence<1, 1, 4, 4>;
    using GemmABlockTransferThreadClusterLengths_GemmG_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

    using GemmBBlockTransferThreadSliceLengths_GemmG_GemmK0_GemmN_GemmK1   = Sequence<1, 1, 2, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmG_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1      = 4;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmG_GemmK0_GemmM_GemmK1   = Sequence<1, 1, 2, 4>;
    using GemmABlockTransferThreadClusterLengths_GemmG_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

    using GemmBBlockTransferThreadSliceLengths_GemmG_GemmK0_GemmN_GemmK1   = Sequence<1, 1, 2, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmG_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [256, 256, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 256;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1      = 8;

    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 4;

    using GemmABlockTransferThreadSliceLengths_GemmG_GemmK0_GemmM_GemmK1   = Sequence<1, 1, 4, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmG_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;

    using GemmBBlockTransferThreadSliceLengths_GemmG_GemmK0_GemmN_GemmK1   = Sequence<1, 1, 4, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmG_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1      = 8;

    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmG_GemmK0_GemmM_GemmK1   = Sequence<1, 1, 4, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmG_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;

    using GemmBBlockTransferThreadSliceLengths_GemmG_GemmK0_GemmN_GemmK1   = Sequence<1, 1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmG_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [128, 256, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 256;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1      = 8;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 4;

    using GemmABlockTransferThreadSliceLengths_GemmG_GemmK0_GemmM_GemmK1   = Sequence<1, 1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmG_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;

    using GemmBBlockTransferThreadSliceLengths_GemmG_GemmK0_GemmN_GemmK1   = Sequence<1, 1, 4, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmG_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1      = 8;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmG_GemmK0_GemmM_GemmK1   = Sequence<1, 1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmG_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;

    using GemmBBlockTransferThreadSliceLengths_GemmG_GemmK0_GemmN_GemmK1   = Sequence<1, 1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmG_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif

    const auto descs =
        transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk_pad(in_n_hi_wi_g_c_desc,
                                                                          wei_g_k_y_x_c_desc,
                                                                          out_n_ho_wo_g_k_desc,
                                                                          conv_strides,
                                                                          conv_dilations,
                                                                          in_left_pads,
                                                                          in_right_pads,
                                                                          Number<GemmK1>{});

    const auto in_gemmg_gemmk0_gemmm_gemmk1_grid_desc  = descs[I0];
    const auto wei_gemmg_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
    const auto out_gemmg_gemmm_gemmn_grid_desc         = descs[I2];

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto in_gemmg_gemmk0_gemmm_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: GemmG
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: GemmM
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 3+: GemmK1
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: GemmG
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: GemmM
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto wei_gemmg_gemmk0_gemmn_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 0+: GemmG
                              Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 1+: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 2+: GemmN
                              Sequence<0, 0, 0, 0, 0, 0, 0>{}),  // 3+: GemmK1
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 0-: GemmG
                              Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 1-: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 2-: GemmN
                              Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto out_gemmg_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: GemmG
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 7+: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 8+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: GemmG
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 7-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 8-: N2

    constexpr auto in_gemmg_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};

    constexpr auto wei_gemmg_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0>{};

    // Floating-point operation count of the grouped convolution:
    // 2 * G * N * K * Ho * Wo * C * Y * X. It does not depend on the timing run,
    // so it is computed once here instead of on every loop iteration.
    const std::size_t flop = [&]() -> std::size_t {
        const auto G = wei_g_k_y_x_c_lengths[I0];
        const auto N = out_n_ho_wo_g_k_lengths[I0];
        const auto K = out_n_ho_wo_g_k_lengths[I4];
        const auto C = wei_g_k_y_x_c_lengths[I4];

        const auto Ho = out_n_ho_wo_g_k_lengths[I1];
        const auto Wo = out_n_ho_wo_g_k_lengths[I2];

        const auto Y = wei_g_k_y_x_c_lengths[I2];
        const auto X = wei_g_k_y_x_c_lengths[I3];

        return std::size_t(2) * G * N * K * Ho * Wo * C * Y * X;
    }();

    // Run the complete timed GEMM 5 times and report every measurement.
    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = driver_gemm_xdlops_v3r1<
            BlockSize,
            TInWei,
            TAcc,
            TOut,
            InMemoryDataOperationEnum_t::Set,
            decltype(in_gemmg_gemmk0_gemmm_gemmk1_grid_desc),
            decltype(wei_gemmg_gemmk0_gemmn_gemmk1_grid_desc),
            decltype(out_gemmg_gemmm_gemmn_grid_desc),
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerXDL,
            GemmNPerXDL,
            GemmK1,
            MRepeat,
            NRepeat,
            GemmABlockTransferThreadSliceLengths_GemmG_GemmK0_GemmM_GemmK1,
            GemmABlockTransferThreadClusterLengths_GemmG_GemmK0_GemmM_GemmK1,
            Sequence<0, 2, 1, 3>,
            Sequence<0, 2, 1, 3>,
            3,
            GemmABlockTransferSrcScalarPerVector_GemmK1,
            GemmABlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
            GemmBBlockTransferThreadSliceLengths_GemmG_GemmK0_GemmN_GemmK1,
            GemmBBlockTransferThreadClusterLengths_GemmG_GemmK0_GemmN_GemmK1,
            Sequence<0, 2, 1, 3>,
            Sequence<0, 2, 1, 3>,
            3,
            GemmBBlockTransferSrcScalarPerVector_GemmK1,
            GemmBBlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
            Sequence<0, 3, 4, 1, 2, 8, 6, 5, 7>,
            8,
            GemmCThreadTransferDstScalarPerVector,
            decltype(in_gemmg_gemmk0_gemmm_gemmk1_grid_step_hacks),
            decltype(wei_gemmg_gemmk0_gemmn_gemmk1_grid_step_hacks),
            decltype(out_gemmg_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
            decltype(in_gemmg_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
            decltype(wei_gemmg_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
            false // CAccessOrderMRepeatNRepeat
            >(static_cast<TInWei*>(in_n_hi_wi_g_c_device_buf.GetDeviceBuffer()),
              static_cast<TInWei*>(wei_g_k_y_x_c_device_buf.GetDeviceBuffer()),
              static_cast<TOut*>(out_n_ho_wo_g_k_device_buf.GetDeviceBuffer()),
              in_gemmg_gemmk0_gemmm_gemmk1_grid_desc,
              wei_gemmg_gemmk0_gemmn_gemmk1_grid_desc,
              out_gemmg_gemmm_gemmn_grid_desc,
              in_gemmg_gemmk0_gemmm_gemmk1_grid_step_hacks,
              wei_gemmg_gemmk0_gemmn_gemmk1_grid_step_hacks,
              out_gemmg_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
              in_gemmg_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
              wei_gemmg_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
              nrepeat);

        // flop / 1e9 / time_in_ms == flop / (1e12 * time_in_s) == TFlop/s
        const float perf =
            static_cast<float>(flop) / (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                  << std::endl;
    }

    // copy result back to host
    out_n_ho_wo_g_k_device_buf.FromDevice(out_n_ho_wo_g_k.mData.data());
}
