// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"

namespace ck_tile {

// A is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
//
// Block-level GEMM computing C += A * B, where
//   A : tile window of lengths (kM, kK)  (dimension 0 = M, dimension 1 = K)
//   B : tile window of lengths (kN, kK)  (dimension 0 = N, dimension 1 = K)
//   C : distributed tensor of lengths (kM, kN), created by MakeCBlockTile()
// with kM / kN / kK taken from Problem::BlockGemmShape.
//
// Implementation outline: the complete A and B tiles are read with load_tile() using
// Policy::MakeALDSTileDistribution / Policy::MakeBLDSTileDistribution, then the warp
// GEMM selected by Policy::GetWarpGemmMWarpNWarp is applied to every
// (kIter, mIter, nIter) sub-tile.
template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
struct BlockGemmASmemBSmemCRegV1
{
    using Problem        = remove_cvref_t<Problem_>;
    using Policy         = remove_cvref_t<Policy_>;
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    // C += A * B
    //
    // c_block_tensor : accumulator, updated in place. Its Y dimensions must be laid out
    //                  like MakeCBlockTile() (leading Y dims = mIter, nIter).
    //                  NOTE(review): only the data type is checked here.
    // a_block_window : A tile window; lengths must equal (kM, kK).
    // b_block_window : B tile window; lengths must equal (kN, kK).
    template <typename CBlockTensor, typename ABlockWindow, typename BBlockWindow>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   const ABlockWindow& a_block_window,
                                   const BBlockWindow& b_block_window) const
    {
        static_assert(std::is_same_v<ADataType, typename ABlockWindow::DataType> &&
                          std::is_same_v<BDataType, typename BBlockWindow::DataType> &&
                          std::is_same_v<CDataType, typename CBlockTensor::DataType>,
                      "wrong!");

        constexpr index_t MPerBlock  = ABlockWindow{}.get_window_lengths()[number<0>{}];
        constexpr index_t NPerBlock  = BBlockWindow{}.get_window_lengths()[number<0>{}];
        constexpr index_t KPerBlock  = ABlockWindow{}.get_window_lengths()[number<1>{}];
        constexpr index_t BKPerBlock = BBlockWindow{}.get_window_lengths()[number<1>{}];

        static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
                          KPerBlock == BlockGemmShape::kK,
                      "wrong!");
        // A and B must agree on the reduction (K) length.
        static_assert(BKPerBlock == KPerBlock, "wrong! A and B windows differ in K length");

        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        // The per-warp iteration counts below use integer division; a remainder would
        // silently leave part of the block tile uncomputed, so reject it at compile time.
        static_assert(MPerBlock % (MWarp * WG::kM) == 0,
                      "wrong! M block length is not a multiple of MWarp * WG::kM");
        static_assert(NPerBlock % (NWarp * WG::kN) == 0,
                      "wrong! N block length is not a multiple of NWarp * WG::kN");
        static_assert(KPerBlock % WG::kK == 0,
                      "wrong! K block length is not a multiple of WG::kK");

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        // Re-window the A/B tensor views with the policy's read distributions.
        // Window lengths must be compile-time (number<>) for a statically distributed
        // window, and the block windows' origins are forwarded so that windows that do
        // not start at (0, 0) read the intended sub-tile.
        auto a_copy_window =
            make_tile_window(a_block_window.get_bottom_tensor_view(),
                             make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                             a_block_window.get_window_origin(),
                             Policy::template MakeALDSTileDistribution<Problem>());
        auto b_copy_window =
            make_tile_window(b_block_window.get_bottom_tensor_view(),
                             make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                             b_block_window.get_window_origin(),
                             Policy::template MakeBLDSTileDistribution<Problem>());

        auto a_block_tensor = load_tile(a_copy_window);
        auto b_block_tensor = load_tile(b_copy_window);

        using AWarpDstr = typename WG::AWarpDstr;
        using BWarpDstr = typename WG::BWarpDstr;
        using CWarpDstr = typename WG::CWarpDstr;

        using AWarpTensor = typename WG::AWarpTensor;
        using BWarpTensor = typename WG::BWarpTensor;
        using CWarpTensor = typename WG::CWarpTensor;

        // Per-warp-tile Y lengths, used as the trailing part of each Y-slice.
        constexpr auto a_warp_y_lengths =
            to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto b_warp_y_lengths =
            to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());

        constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
        constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        // Hot loop. For each (kIter, mIter) the A warp slice is extracted once and
        // reused across all nIter; for each (mIter, nIter) the C warp slice is read,
        // accumulated by the warp GEMM, and written back.
        // NOTE(review): the slicing assumes the A/B read distributions put
        // (MIter, KIter) / (NIter, KIter) in their two leading Y dimensions —
        // TODO confirm against the Policy's MakeA/BLDSTileDistribution.
        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                // read A warp tensor from A block tensor
                AWarpTensor a_warp_tensor;

                a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
                    merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                    merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                    // read B warp tensor from B block tensor
                    BWarpTensor b_warp_tensor;

                    b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

                    // read C warp tensor from C block tensor
                    CWarpTensor c_warp_tensor;

                    c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                    // warp GEMM: c_warp_tensor += a_warp_tensor * b_warp_tensor
                    WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

                    // write C warp tensor back into C block tensor
                    c_block_tensor.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                        c_warp_tensor.get_thread_buffer());
                });
            });
        });
    }

    // Creates the C tile (kM x kN) used as the accumulator of this block GEMM.
    //
    // The outer encoding splits M into (MIterPerWarp, MWarp) and N into
    // (NIterPerWarp, NWarp): the single P dimension maps to (MWarp, NWarp), and the two
    // outer Y dimensions iterate (MIterPerWarp, NIterPerWarp). That outer encoding is
    // then embedded with the warp GEMM's C encoding (WG::CWarpDstrEncoding).
    CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
    {
        constexpr index_t MPerBlock = BlockGemmShape::kM;
        constexpr index_t NPerBlock = BlockGemmShape::kN;

        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        static_assert(MPerBlock % (MWarp * WG::kM) == 0 && NPerBlock % (NWarp * WG::kN) == 0,
                      "wrong! C block tile is not a multiple of the warp GEMM tiles");

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);

        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});

        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);

        return make_static_distributed_tensor<CDataType>(c_block_dstr);
    }

    // C = A * B
    //
    // Convenience overload: creates a C tile with MakeCBlockTile(), zero-fills it,
    // accumulates A * B into it through the three-argument overload, and returns it.
    template <typename ABlockWindow, typename BBlockWindow>
    CK_TILE_DEVICE auto operator()(const ABlockWindow& a_block_window,
                                   const BBlockWindow& b_block_window) const
    {
        auto c_block_tensor = MakeCBlockTile();
        // The three-argument overload computes C += A * B, so C must start from zero
        // for this overload to return exactly A * B.
        clear_tile(c_block_tensor);
        operator()(c_block_tensor, a_block_window, b_block_window);
        return c_block_tensor;
    }
};
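// NOTE: The commented-out code below is not compiled. It is an alternative
// construction of per-warp A/B tile windows (one window per (mIter, kIter) and
// (nIter, kIter) pair), kept outside BlockGemmASmemBSmemCRegV1 for reference only.
// It refers to names (iMWarp, iNWarp, MPerBlockPerIter, NPerBlockPerIter,
// KPerBlockPerIter) that are not defined in live code.
        // construct A-warp-window
        // auto a_warp_window_tmp = make_tile_window(
        //     a_block_window.get_bottom_tensor_view(),
        //     make_tuple(number<WG::kM>{}, number<WG::kK>{}),
        //     a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
        //     make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));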
// #if 0 // FIXME: using array will cause register spill
//         array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
//             {a_warp_window_tmp}};

//         for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
//         {
//             for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
//             {
//                 move_tile_window(a_warp_windows(mIter)(kIter),
//                                  {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
//             }
//         }
// #else
//         statically_indexed_array<
//             statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
//             MIterPerWarp>
//             a_warp_windows;

//         static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
//             static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
//                 a_warp_windows(mIter)(kIter) = a_warp_window_tmp;

//                 move_tile_window(a_warp_windows(mIter)(kIter),
//                                  {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
//             });
//         });
// #endif

        // construct B-warp-window
//         auto b_warp_window_tmp = make_tile_window(
//             b_block_window.get_bottom_tensor_view(),
//             make_tuple(number<WG::kN>{}, number<WG::kK>{}),
//             b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
//             make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));

// #if 0 // FIXME: using array will cause register spill
//         array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
//             {b_warp_window_tmp}};

//         for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
//         {
//             for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
//             {
//                 move_tile_window(b_warp_windows(nIter)(kIter),
//                                  {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
//             }
//         }
// #else
//         statically_indexed_array<
//             statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
//             NIterPerWarp>
//             b_warp_windows;

//         static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
//             static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
//                 b_warp_windows(nIter)(kIter) = b_warp_window_tmp;

//                 move_tile_window(b_warp_windows(nIter)(kIter),
//                                  {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
//             });
//         });
// #endif
} // namespace ck_tile
