// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp"

namespace ck_tile {

//  A Tile Window: global memory
//  B Tile Window: global memory
//  C Distributed tensor: register
// Block-level flatmm pipeline, version 1 (work in progress).
//
// Template parameters:
//   Problem        - supplies the data types, layouts, block GEMM shape, block size, vector
//                    sizes and padding flags that are re-exported below.
//   PipelinePolicy - supplies the block flatmm operator, the LDS block descriptors, the DRAM tile
//                    distributions and the shared-memory size.
//
// NOTE: the active code path of operator() only prefetches the first A / flat-B tiles. The GEMM
// main loop is not implemented yet; the intended structure is kept as disabled reference code
// inside operator().
//
// All debug-only entities (the extra B DRAM window, the dbg_* buffers, device printf) are
// compiled only when FEIFEI_DEBUG is non-zero, so the header also parses with the flag off.
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
struct FlatmmPipelineAGmemBGmemCRegV1
{
    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};
    static constexpr auto I2 = number<2>{};

    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    using ALayout = remove_cvref_t<typename Problem::ALayout>;
    using BLayout = remove_cvref_t<typename Problem::BLayout>;
    using CLayout = remove_cvref_t<typename Problem::CLayout>;

    // Block-level flatmm operator type selected by the policy.
    using BlockFlatmm =
        remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;

    static constexpr index_t BlockSize = Problem::kBlockSize;

    // Per-block tile extents taken from the block GEMM shape.
    static constexpr index_t kMPerBlock = BlockGemmShape::kM;
    static constexpr index_t kNPerBlock = BlockGemmShape::kN;
    static constexpr index_t kKPerBlock = BlockGemmShape::kK;

    static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
    static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
    static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }

    static constexpr bool kPadM = Problem::kPadM;
    static constexpr bool kPadN = Problem::kPadN;
    static constexpr bool kPadK = Problem::kPadK;

    // Returns the byte size of the A LDS block (rounded up to a multiple of 16 bytes) plus the
    // byte size of the B LDS block, both taken from the policy's LDS block descriptors.
    CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
    {
        return integer_divide_ceil(sizeof(ADataType) *
                                       PipelinePolicy::template MakeALdsBlockDescriptor<Problem>()
                                           .get_element_space_size(),
                                   16) *
                   16 +
               sizeof(BDataType) * PipelinePolicy::template MakeBLdsBlockDescriptor<Problem>()
                                       .get_element_space_size();
    }

    // Returns the shared-memory size reported by the policy for this Problem.
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return PipelinePolicy::template GetSmemSize<Problem>();
    }

    // Forwards the policy's IsTransposeC() answer.
    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
    {
        return PipelinePolicy::IsTransposeC();
    }

    // Main pipeline entry point.
    //
    // Parameters:
    //   a_dram_block_window_tmp      - A block window; statically checked to hold ADataType and
    //                                  to have lengths kMPerBlock x kKPerBlock.
    //   a_element_func               - element-wise functor for A; not used by the active path.
    //   b_flat_dram_block_window_tmp - flat B block window; its first length is statically
    //                                  checked against kNPerBlock.
    //   b_element_func               - element-wise functor for B; not used by the active path.
    //   num_loop                     - loop count; currently only printed in debug builds.
    //   p_smem                       - shared-memory pointer; not used by the active path.
    //
    // Debug-only parameters (present only when FEIFEI_DEBUG is non-zero):
    //   b_dram_block_window_tmp      - regular (non-flat) B block window, loaded for inspection.
    //   dbg_int / dbg_fp32 / dbg_f168 - per-thread debug buffers with DEBUG_CNT slots per thread;
    //                                  dbg_f168 is reinterpreted as half_t*.
    //
    // Returns:
    //   FEIFEI_DEBUG != 0: nullptr, after dumping the prefetched A tile into the half debug
    //   buffer. FEIFEI_DEBUG == 0: nothing (the return type deduces to void) because the GEMM
    //   itself is still to be implemented.
    template <typename ADramBlockWindowTmp,
              typename BFlatBlockWindowTmp,
              typename AElementFunction,
              typename BElementFunction
#if FEIFEI_DEBUG
            , typename BDramBlockWindowTmp
#endif
    >
    CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                        const AElementFunction& a_element_func,
                                        const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
                                        const BElementFunction& b_element_func,
                                        index_t num_loop,
                                        void* p_smem
#if FEIFEI_DEBUG
                                        ,
                                        const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                        int* dbg_int,
                                        float* dbg_fp32,
                                        void* dbg_f168
#endif
    ) const
    {
#if FEIFEI_DEBUG
        if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
        {
            printf("[PIPELN] FlatmmPipelinen():\n");
            printf("[PIPELN] num_loop = %d\n", num_loop);
        }

        // Linear thread id across a 2D grid of 2D blocks; selects this thread's debug slots.
        uint32_t tidx = threadIdx.x;
        uint32_t tidy = threadIdx.y;
        uint32_t bidx = blockIdx.x;
        uint32_t bidy = blockIdx.y;
        uint32_t bdmx = blockDim.x;
        uint32_t bdmy = blockDim.y;
        uint32_t gdmx = gridDim.x;
        uint32_t gdmy = gridDim.y;
        uint32_t gid  = ((bdmx * bdmy) * gdmx) * bidy + (bdmx * bdmy) * bidx + bdmx * tidy + tidx;

        // Pre-fill this thread's DEBUG_CNT slots in every debug buffer with 1.
        half_t* dbg_f16 = static_cast<half_t*>(dbg_f168);
        for(int i = 0; i < DEBUG_CNT; i++)
        {
            dbg_int[gid * DEBUG_CNT + i]  = 1;
            dbg_fp32[gid * DEBUG_CNT + i] = 1.0f;
            dbg_f16[gid * DEBUG_CNT + i] = ck_tile::type_convert<ck_tile::half_t>(1.0f);
        }
#endif
        static_assert(
            std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>>,
            "A block window data type must match Problem::ADataType");
#if FEIFEI_DEBUG
        // BDramBlockWindowTmp only exists in debug builds, so its check must be guarded too.
        static_assert(
            std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
            "B block window data type must match Problem::BDataType");
#endif

        static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kNPerBlock == BFlatBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "block window lengths must match kMPerBlock / kNPerBlock / kKPerBlock");
#if FEIFEI_DEBUG
        // Diagnostic dump of the tile extents versus the incoming window lengths.
        if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
        {
            printf("[PIPELN] kMPerBlock = %d, winM = %d\n", kMPerBlock,
                static_cast<int>(ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}]));
            printf("[PIPELN] kNPerBlock = %d, winN = %d\n", kNPerBlock,
                static_cast<int>(BFlatBlockWindowTmp{}.get_window_lengths()[number<0>{}]));
            printf("[PIPELN] kNPerBlock = %d, winN = %d\n", kNPerBlock,
                static_cast<int>(BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}]));
            printf("[PIPELN] kKPerBlock = %d, winK = %d\n", kKPerBlock,
                static_cast<int>(ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}]));
        }
#endif

#if 1
        // feifei TODO: Implement gemm here
        // Get block flatmm
        auto block_flatmm = BlockFlatmm(); // struct BlockFlatmmASmemBSmemCRegV1

        // A DRAM tile window for load
        auto a_copy_dram_window =                                               // tile_window_with_static_distribution
            make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),  // from kernel gemm_pad_views
                             make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
                             a_dram_block_window_tmp.get_window_origin(),
                             PipelinePolicy::template MakeADramTileDistribution<Problem>());

#if FEIFEI_DEBUG
        // B DRAM tile window for load (debug only: b_dram_block_window_tmp is a debug parameter)
        auto b_copy_dram_window =                                               // tile_window_with_static_distribution
            make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),  // from kernel gemm_pad_views
                             make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
                             b_dram_block_window_tmp.get_window_origin(),
                             PipelinePolicy::template MakeBDramTileDistribution<Problem>());
#endif

        // B flat DRAM window for load
        auto b_flat_distribution = PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
        auto b_flat_dram_window =                                               // tile_window_with_static_distribution
            make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(),  // from kernel gemm_pad_views
                             make_tuple(number<kNPerBlock>{}, number<BlockSize>{} * 4),
                             b_flat_dram_block_window_tmp.get_window_origin(),
                             b_flat_distribution);

        // Prefetch -----------------------------------------------------------
        // global read 0
        auto a_block_tile = load_tile(a_copy_dram_window);
#if FEIFEI_DEBUG
        auto b_block_tile = load_tile(b_copy_dram_window);
#endif
        auto b_flat_tile = load_tile(b_flat_dram_window);

#if FEIFEI_DEBUG 
        // debug A global load
        // NOTE(review): the dump loops below are bounded by the per-thread tile size, not by
        // DEBUG_CNT - confirm that DEBUG_CNT is at least as large as the tile buffer.
        int a_block_tile_size_per_thread = a_block_tile.get_thread_buffer_size();
        if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
        {
            printf("[PIPELN] a_block_tile_size_per_thread = %d\n", a_block_tile_size_per_thread);
        }
        for(auto i = 0; i < a_block_tile_size_per_thread; i++)
        {
            dbg_f16[gid * DEBUG_CNT + i] = a_block_tile.get_thread_buffer()[i];
        }

        // debug B global load
        int b_block_tile_size_per_thread = b_block_tile.get_thread_buffer_size();
        if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
        {
            printf("[PIPELN] b_block_tile_size_per_thread = %d\n", b_block_tile_size_per_thread);
        }
        for(auto i = 0; i < b_block_tile_size_per_thread; i++)
        {
            //dbg_f16[gid * DEBUG_CNT + i] = b_block_tile.get_thread_buffer()[i];
        }

        // debug flat B global load
        int b_flat_tile_size_per_thread = b_flat_tile.get_thread_buffer_size();
        if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
        {
            printf("[PIPELN] b_flat_tile_size_per_thread = %d\n", b_flat_tile_size_per_thread);
        }
        for(auto i = 0; i < b_flat_tile_size_per_thread; i++)
        {
            //dbg_f16[gid * DEBUG_CNT + i + b_block_tile_size_per_thread + 4] = b_flat_tile.get_thread_buffer()[i];
        }

        // Debug builds stop here, right after the prefetch.
        return nullptr;
#endif

        // Disabled work-in-progress: LDS staging followed by a single block_flatmm call.
#if 0
        // move to 1
        move_tile_window(a_copy_dram_window, {0, kKPerBlock});
        move_tile_window(b_copy_dram_window, {0, kKPerBlock});

        // A tile in LDS
        ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
        constexpr auto a_lds_block_desc = PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
        auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

        // A LDS tile window for store
        auto a_copy_lds_window = make_tile_window(      // tile_window_with_static_lengths
            a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
        // A LDS tile for block GEMM
        auto a_lds_gemm_window = make_tile_window(      // tile_window_with_static_lengths
            a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // LDS write 0
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
        {
            auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                PipelinePolicy::template MakeShuffledARegBlockDescriptor<Problem>());
            shuffle_tile(a_shuffle_tmp, a_block_tile);
            const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
            store_tile(a_copy_lds_window, a_block_tile_tmp);
        }
        else
        {
            store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
        }

        // B tile in LDS
        constexpr index_t a_lds_block_space_size_aligned = integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * 16;
        BDataType* p_b_lds = static_cast<BDataType*>(static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));

        constexpr auto b_lds_block_desc = PipelinePolicy::template MakeBLdsBlockDescriptor<Problem>();
        auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);

        // B LDS tile window for store
        auto b_copy_lds_window = make_tile_window(
            b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
        // B LDS tile for block GEMM
        auto b_lds_gemm_window = make_tile_window(
            b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // Loop ---------------------------------------------------------------
        // Do flatmm
        block_sync_lds();
        block_flatmm(a_lds_gemm_window, b_lds_gemm_window
#if FEIFEI_DEBUG
                     ,
                     dbg_int,
                     dbg_fp32,
                     dbg_f168
#endif
        );

        // Tail ---------------------------------------------------------------

        return nullptr;
#endif

        // Commented-out reference: full prefetch / main-loop / tail GEMM pipeline.
////////////////////////////////////////////////////////////////////////////////////////////////////
        // A tile in LDS
        /*ADataType* p_a_lds = static_cast<ADataType*>(p_smem);

        constexpr auto a_lds_block_desc =
            PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();

        auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

        constexpr index_t a_lds_block_space_size_aligned =
            integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
            16;

        // B tile in LDS
        BDataType* p_b_lds = static_cast<BDataType*>(
            static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));

        constexpr auto b_lds_block_desc =
            PipelinePolicy::template MakeBLdsBlockDescriptor<Problem>();

        auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);

        // A LDS tile window for store
        auto a_copy_lds_window = make_tile_window(
            a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // B LDS tile window for store
        auto b_copy_lds_window = make_tile_window(
            b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // A LDS tile for block GEMM
        auto a_lds_gemm_window = make_tile_window(
            a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // B LDS tile for block GEMM
        auto b_lds_gemm_window = make_tile_window(
            b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // Block GEMM
        auto block_gemm = BlockFlatmm();

        // Acc register tile
        auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};

        // prefetch
        // global read 0
        //auto a_block_tile = load_tile(a_copy_dram_window);
        //auto b_block_tile = load_tile(b_copy_dram_window);

        {
            // move to 1
            move_tile_window(a_copy_dram_window, {0, kKPerBlock});
            move_tile_window(b_copy_dram_window, {0, kKPerBlock});

            // initialize C
            tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

            // LDS write 0
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
            {
                auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                    PipelinePolicy::template MakeShuffledARegBlockDescriptor<Problem>());
                shuffle_tile(a_shuffle_tmp, a_block_tile);
                const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
                store_tile(a_copy_lds_window, a_block_tile_tmp);
            }
            else
            {
                store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
            }

            // LDS write 0
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
            {
                auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                    PipelinePolicy::template MakeShuffledBRegBlockDescriptor<Problem>());
                shuffle_tile(b_shuffle_tmp, b_block_tile);
                const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
                store_tile(b_copy_lds_window, b_block_tile_tmp);
            }
            else
            {
                store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
            }
        }

        index_t iCounter = num_loop - 1;
        while(iCounter > 0)
        {
            // global read i + 1
            a_block_tile = load_tile(a_copy_dram_window);
            b_block_tile = load_tile(b_copy_dram_window);

            block_sync_lds();

            // GEMM i
            block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);

            block_sync_lds();

            // move to i + 2
            move_tile_window(a_copy_dram_window, {0, kKPerBlock});
            move_tile_window(b_copy_dram_window, {0, kKPerBlock});

            // LDS write i + 1
            const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
            store_tile(a_copy_lds_window, a_block_tile_tmp);

            // LDS write i + 1
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
            {
                auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
                    PipelinePolicy::template MakeShuffledBRegBlockDescriptor<Problem>());
                shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
                store_tile(b_copy_lds_window,
                           tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
            }
            else
            {
                const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
                store_tile(b_copy_lds_window, b_block_tile_tmp);
            }

            iCounter--;
        }

        // tail
        {
            block_sync_lds();

            // GEMM num_loop - 1
            block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
        }

        int c_block_tile_size_per_thread = c_block_tile.get_thread_buffer_size();
        if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
        {
            printf("[PIPELN] c_block_tile_size_per_thread = %d\n", c_block_tile_size_per_thread);
        }
        for(auto i = 0; i < c_block_tile_size_per_thread; i++)
        {
            //dbg_fp32[gid * DEBUG_CNT + i] = c_block_tile.get_thread_buffer()[i];
            dbg_fp32[gid * DEBUG_CNT + i] = 3.12f;
            c_block_tile.get_thread_buffer()[i] = 1.23f;
        }
        return c_block_tile;*/
////////////////////////////////////////////////////////////////////////////////////////////////////


#else
        // Inactive alternative (the enclosing condition is `#if 1`): LDS-staged GEMM pipeline.
        // A tile in LDS
        ADataType* p_a_lds = static_cast<ADataType*>(p_smem);

        constexpr auto a_lds_block_desc =
            PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();

        auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

        constexpr index_t a_lds_block_space_size_aligned =
            integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
            16;

        // B tile in LDS
        BDataType* p_b_lds = static_cast<BDataType*>(
            static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));

        constexpr auto b_lds_block_desc =
            PipelinePolicy::template MakeBLdsBlockDescriptor<Problem>();

        auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);

        // A DRAM tile window for load
        auto a_copy_dram_window =
            make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
                             a_dram_block_window_tmp.get_window_origin(),
                             PipelinePolicy::template MakeADramTileDistribution<Problem>());

        // A LDS tile window for store
        auto a_copy_lds_window = make_tile_window(
            a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // B DRAM tile window for load
        auto b_copy_dram_window =
            make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
                             b_dram_block_window_tmp.get_window_origin(),
                             PipelinePolicy::template MakeBDramTileDistribution<Problem>());

        // B LDS tile window for store
        auto b_copy_lds_window = make_tile_window(
            b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // A LDS tile for block GEMM
        auto a_lds_gemm_window = make_tile_window(
            a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // B LDS tile for block GEMM
        auto b_lds_gemm_window = make_tile_window(
            b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // Block GEMM
        auto block_gemm = BlockGemm();

        // Acc register tile
        auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};

        // return c_block_tile;

        // prefetch
        // global read 0
        auto a_block_tile = load_tile(a_copy_dram_window);
        auto b_block_tile = load_tile(b_copy_dram_window);

        {
            // move to 1
            move_tile_window(a_copy_dram_window, {0, kKPerBlock});
            move_tile_window(b_copy_dram_window, {0, kKPerBlock});

            // initialize C
            tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

            // LDS write 0
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
            {
                auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                    PipelinePolicy::template MakeShuffledARegBlockDescriptor<Problem>());
                shuffle_tile(a_shuffle_tmp, a_block_tile);
                const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
                store_tile(a_copy_lds_window, a_block_tile_tmp);
            }
            else
            {
                store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
            }

            // LDS write 0
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
            {
                auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                    PipelinePolicy::template MakeShuffledBRegBlockDescriptor<Problem>());
                shuffle_tile(b_shuffle_tmp, b_block_tile);
                const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
                store_tile(b_copy_lds_window, b_block_tile_tmp);
            }
            else
            {
                store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
            }
        }

        index_t iCounter = num_loop - 1;
        while(iCounter > 0)
        {
            // global read i + 1
            a_block_tile = load_tile(a_copy_dram_window);
            b_block_tile = load_tile(b_copy_dram_window);

            block_sync_lds();

            // GEMM i
            block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);

            block_sync_lds();

            // move to i + 2
            move_tile_window(a_copy_dram_window, {0, kKPerBlock});
            move_tile_window(b_copy_dram_window, {0, kKPerBlock});

            // LDS write i + 1
            const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
            store_tile(a_copy_lds_window, a_block_tile_tmp);

            // LDS write i + 1
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
            {
                auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
                    PipelinePolicy::template MakeShuffledBRegBlockDescriptor<Problem>());
                shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
                store_tile(b_copy_lds_window,
                           tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
            }
            else
            {
                const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
                store_tile(b_copy_lds_window, b_block_tile_tmp);
            }

            iCounter--;
        }

        // tail
        {
            block_sync_lds();

            // GEMM num_loop - 1
            block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
        }

        return c_block_tile;
#endif
    }

    // Convenience overload: forwards to the main operator() with identity element-wise functors
    // for A and B. The debug-only arguments are forwarded unchanged when FEIFEI_DEBUG is set.
    template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp
#if FEIFEI_DEBUG
    , typename BDramBlockWindowTmp
#endif
    >    
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                   const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
                                   index_t num_loop,
                                   void* p_smem
#if FEIFEI_DEBUG
                                   ,
                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   int* dbg_int,
                                   float* dbg_fp32,
                                   void* dbg_f168
#endif
    ) const
    {
        return operator()(
            a_dram_block_window_tmp,
            [](const ADataType & a) { return a; },
            b_flat_dram_block_window_tmp,
            [](const BDataType & b) { return b; },
            num_loop,
            p_smem
#if FEIFEI_DEBUG
            ,
            b_dram_block_window_tmp,
            dbg_int,
            dbg_fp32,
            dbg_f168
#endif
        );
    }
};

} // namespace ck_tile
