// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <hip/hip_runtime.h>

#include <cstdio>
#include <cstring>
#include <iostream>
#include <ostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>

#include "ck_tile/host.hpp"
#include "flatmm_basic.hpp"

#if 1
/// @brief Build the flatmm kernel type for one fixed tile configuration and launch it.
///
/// This is the variant selected by the surrounding `#if 1`. The configuration is
/// hard-coded: block tile 128x128x32, 2x2x1 warps per block, warp tile 32x32x8,
/// no M/N/K padding, one block per CU, 2D tile partitioner.
///
/// ADataType, BDataType, AccDataType and CDataType are not declared in this file;
/// presumably they come from "flatmm_basic.hpp" -- confirm.
///
/// @tparam ALayout Layout tag forwarded to ck_tile::TileGemmTraits.
/// @tparam BLayout Layout tag forwarded to ck_tile::TileGemmTraits.
/// @tparam CLayout Layout tag forwarded to ck_tile::TileGemmTraits and to the
///                 CShuffle epilogue problem.
/// @param args Host arguments. Passed to Kernel::MakeKernelArgs; args.M, args.N and
///             args.k_batch are passed to Kernel::GridSize.
/// @param s    Stream configuration handed to ck_tile::launch_kernel. When
///             s.log_level_ > 0 the launch geometry is printed to std::cout.
/// @return The value returned by ck_tile::launch_kernel (presumably the average
///         kernel time -- confirm against ck_tile).
/// @throws std::runtime_error if Kernel::IsSupportedArgument rejects the kernel arguments.
template <typename ALayout, typename BLayout, typename CLayout>
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s)
{
    // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
    constexpr bool kPadM = false;
    constexpr bool kPadN = false;
    constexpr bool kPadK = false;

    constexpr int kBlockPerCu = 1;

    // This part comes from the Codegen
    constexpr ck_tile::index_t M_Tile = 128;
    constexpr ck_tile::index_t N_Tile = 128;
    constexpr ck_tile::index_t K_Tile = 32;

    constexpr ck_tile::index_t M_Warp = 2;
    constexpr ck_tile::index_t N_Warp = 2;
    constexpr ck_tile::index_t K_Warp = 1;

    constexpr ck_tile::index_t M_Warp_Tile = 32;
    constexpr ck_tile::index_t N_Warp_Tile = 32;
    constexpr ck_tile::index_t K_Warp_Tile = 8;

    using CodegenGemmShape =
        ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                               ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                               ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;

    using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;

    using CodegenGemmTraits =
        ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
    using CodegenPipelineProblem = ck_tile::
        GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
    using CodegenFlatmmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy;

    // The CShuffle epilogue is used unconditionally, regardless of CLayout.
    using GemmEpilogue = ck_tile::CShuffleEpilogue<
        ck_tile::CShuffleEpilogueProblem<AccDataType,
                                         CDataType,
                                         CLayout,
                                         CodegenPipelineProblem::kBlockSize,
                                         TilePartitioner::MPerBlock,
                                         TilePartitioner::NPerBlock,
                                         M_Warp,
                                         N_Warp,
                                         M_Warp_Tile,
                                         N_Warp_Tile,
                                         K_Warp_Tile,
                                         CodegenPipelineProblem::TransposeC>>;

    using CodegenFlatmmPipeline =
        ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenFlatmmPolicy>;

    // ToDo: Will add the codegen part to test different pipeline policies.
    // For now the policy is fixed to UniversalFlatmmPipelineAgBgCrPolicy.
    using Kernel = ck_tile::FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;

    auto kargs = Kernel::MakeKernelArgs(args);

    const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
    // Must be constexpr: blocks.x is used below as a template argument of make_kernel.
    constexpr dim3 blocks = Kernel::BlockSize();

#if FEIFEI_DEBUG
    printf("[FEIFEI] --- flatmm_calc() ---\n");
    printf("[FEIFEI] BlockPerCu = %d\n", static_cast<int>(kBlockPerCu));
    printf("[FEIFEI] BlockTile M = %d\n", static_cast<int>(M_Tile));
    printf("[FEIFEI] BlockTile N = %d\n", static_cast<int>(N_Tile));
    printf("[FEIFEI] BlockTile K = %d\n", static_cast<int>(K_Tile));
    printf("[FEIFEI] WavePerBlock M = %d\n", static_cast<int>(M_Warp));
    printf("[FEIFEI] WavePerBlock N = %d\n", static_cast<int>(N_Warp));
    printf("[FEIFEI] WavePerBlock K = %d\n", static_cast<int>(K_Warp));
    printf("[FEIFEI] WaveTile M = %d\n", static_cast<int>(M_Warp_Tile));
    printf("[FEIFEI] WaveTile N = %d\n", static_cast<int>(N_Warp_Tile));
    printf("[FEIFEI] WaveTile K = %d\n", static_cast<int>(K_Warp_Tile));
    // dim3 components are unsigned, so they are printed with %u rather than %d.
    printf("[FEIFEI] grids = [%u, %u, %u]\n", grids.x, grids.y, grids.z);
    printf("[FEIFEI] blocks = [%u, %u, %u]\n", blocks.x, blocks.y, blocks.z);
#endif

    if(!Kernel::IsSupportedArgument(kargs))
    {
        throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
    }

    if(s.log_level_ > 0)
    {
        std::cout << "Launching kernel with args:"
                  << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                  << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                  << std::endl;
    }

    // The literal 0 is the fourth make_kernel argument (presumably the dynamic LDS byte
    // count -- confirm against ck_tile::make_kernel).
    float ave_time = ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));

    return ave_time;
}
#else
/// @brief Alternative flatmm launcher using a 1D tile partitioner and a deeper K tile.
///
/// This definition sits in the `#else` branch of an `#if 1`, so it is not compiled
/// as the file stands. The configuration is hard-coded: block tile 128x128x64,
/// 2x2x1 warps per block, warp tile 32x32x16, no M/N/K padding, one block per CU.
///
/// ADataType, BDataType, AccDataType and CDataType are not declared in this file;
/// presumably they come from "flatmm_basic.hpp" -- confirm.
///
/// @tparam ALayout Layout tag forwarded to ck_tile::TileGemmTraits.
/// @tparam BLayout Layout tag forwarded to ck_tile::TileGemmTraits.
/// @tparam CLayout Layout tag forwarded to ck_tile::TileGemmTraits and to the
///                 CShuffle epilogue problem.
/// @param args Host arguments. Passed to Kernel::MakeKernelArgs; args.M, args.N and
///             args.k_batch are passed to Kernel::GridSize.
/// @param s    Stream configuration handed to ck_tile::launch_kernel. When
///             s.log_level_ > 0 the launch geometry is printed to std::cout.
/// @return The value returned by ck_tile::launch_kernel (presumably the average
///         kernel time -- confirm against ck_tile).
/// @throws std::runtime_error if Kernel::IsSupportedArgument rejects the kernel arguments.
template <typename ALayout, typename BLayout, typename CLayout>
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s)
{
    // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
    constexpr bool kPadM = false;
    constexpr bool kPadN = false;
    constexpr bool kPadK = false;

    constexpr int kBlockPerCu = 1;

    // This part comes from the Codegen.
    // Disabled alternative configuration, kept for reference:
    //   M_Tile = 128,     N_Tile = 128,     K_Tile = 32
    //   M_Warp = 1,       N_Warp = 4,       K_Warp = 1
    //   M_Warp_Tile = 32, N_Warp_Tile = 32, K_Warp_Tile = 8
    constexpr ck_tile::index_t M_Tile = 128;
    constexpr ck_tile::index_t N_Tile = 128;
    constexpr ck_tile::index_t K_Tile = 64;

    constexpr ck_tile::index_t M_Warp = 2;
    constexpr ck_tile::index_t N_Warp = 2;
    constexpr ck_tile::index_t K_Warp = 1;

    constexpr ck_tile::index_t M_Warp_Tile = 32;
    constexpr ck_tile::index_t N_Warp_Tile = 32;
    constexpr ck_tile::index_t K_Warp_Tile = 16;

    using CodegenGemmShape =
        ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                               ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                               ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;

    using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;

    using CodegenGemmTraits =
        ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
    using CodegenPipelineProblem = ck_tile::
        GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
    using GemmEpilogue = ck_tile::CShuffleEpilogue<
        ck_tile::CShuffleEpilogueProblem<AccDataType,
                                         CDataType,
                                         CLayout,
                                         CodegenPipelineProblem::kBlockSize,
                                         TilePartitioner::MPerBlock,
                                         TilePartitioner::NPerBlock,
                                         M_Warp,
                                         N_Warp,
                                         M_Warp_Tile,
                                         N_Warp_Tile,
                                         K_Warp_Tile,
                                         CodegenPipelineProblem::TransposeC>>;

    using CodegenFlatmmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy;
    using CodegenFlatmmPipeline =
        ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenFlatmmPolicy>;

    // ToDo: Will add the codegen part to test different pipeline policies.
    // For now the policy is fixed to UniversalFlatmmPipelineAgBgCrPolicy.
    using Kernel = ck_tile::FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;

    auto kargs = Kernel::MakeKernelArgs(args);

    const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
    // Must be constexpr: blocks.x is used below as a template argument of make_kernel.
    constexpr dim3 blocks = Kernel::BlockSize();

#if FEIFEI_DEBUG
    // Disabled debugging scaffold, kept for reference:
    //
    // using BlockFlatmmStruct = ck_tile::remove_cvref_t<
    //     decltype(CodegenFlatmmPolicy::template GetBlockFlatmm<CodegenPipelineProblem>())>;
    // auto block_flatmm = BlockFlatmmStruct(); // struct BlockFlatmmASmemBSmemCRegV1
    // // auto ADramTileDistr =
    // //     CodegenFlatmmPolicy::template MakeADramTileDistribution<CodegenPipelineProblem>();
    //
    // auto kernel = Kernel{};
    // using SplitKBatchOffset = typename Kernel::SplitKBatchOffset;
    // SplitKBatchOffset splitk_batch_offset(args);
    // auto gemm_tensor_views_tuple =
    //     Kernel::template MakeGemmTensorViews<ck_tile::memory_operation_enum::set>(
    //         args.a_ptr, args.b_shuffle_ptr, args.c_ptr, kargs, splitk_batch_offset);

    printf("[FEIFEI] --- flatmm_calc() ---\n");
    printf("[FEIFEI] BlockPerCu = %d\n", static_cast<int>(kBlockPerCu));
    printf("[FEIFEI] BlockTile M = %d\n", static_cast<int>(M_Tile));
    printf("[FEIFEI] BlockTile N = %d\n", static_cast<int>(N_Tile));
    printf("[FEIFEI] BlockTile K = %d\n", static_cast<int>(K_Tile));
    printf("[FEIFEI] WavePerBlock M = %d\n", static_cast<int>(M_Warp));
    printf("[FEIFEI] WavePerBlock N = %d\n", static_cast<int>(N_Warp));
    printf("[FEIFEI] WavePerBlock K = %d\n", static_cast<int>(K_Warp));
    printf("[FEIFEI] WaveTile M = %d\n", static_cast<int>(M_Warp_Tile));
    printf("[FEIFEI] WaveTile N = %d\n", static_cast<int>(N_Warp_Tile));
    printf("[FEIFEI] WaveTile K = %d\n", static_cast<int>(K_Warp_Tile));
    // dim3 components are unsigned, so they are printed with %u rather than %d.
    printf("[FEIFEI] grids = [%u, %u, %u]\n", grids.x, grids.y, grids.z);
    printf("[FEIFEI] blocks = [%u, %u, %u]\n", blocks.x, blocks.y, blocks.z);
#endif

    if(!Kernel::IsSupportedArgument(kargs))
    {
        throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
    }

    if(s.log_level_ > 0)
    {
        std::cout << "Launching kernel with args:"
                  << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                  << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                  << std::endl;
    }

    // The literal 0 is the fourth make_kernel argument (presumably the dynamic LDS byte
    // count -- confirm against ck_tile::make_kernel).
    float ave_time = ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));

    return ave_time;
}
#endif

#include "run_flatmm_example.inc"

/// Program entry point: forwards the command line to run_flatmm_example and maps its
/// result to the process exit status (a truthy result yields 0, a falsy result yields 1).
int main(int argc, char* argv[])
{
    if(run_flatmm_example(argc, argv))
    {
        return 0;
    }
    return 1;
}
