// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"

namespace ck_tile {

// Default policy of the "QR KS VS async" FMHA pipeline.
//
// Everything not defined here comes from BlockFmhaPipelineQXKSVSCustomPolicy instantiated
// with QLoadOnce = true, AsyncCopy = true and NumPrefetchV = 2. This struct defines the
// Q-specific pieces: the per-thread vector width, the tile distribution used to load Q,
// the LDS descriptor for Q, and the resulting shared-memory sizes.
struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
    : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
                                          /* AsyncCopy = */ true,
                                          /* NumPrefetchV = */ 2>
{
    // Number of Q elements a thread handles as one vector.
    //
    // Result is the smaller of
    //   - the number of elements each thread owns in the kM0 x kSubQKHeaddim tile
    //     (tile element count / Problem::kBlockSize), and
    //   - the number of QDataType elements that fit in 16 bytes.
    //
    // MakeQDramTileDistribution() takes its per-thread K length from this function, so
    // the two can no longer drift apart.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
    {
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;

        // Upper bound: one 16-byte access worth of QDataType elements.
        constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);

        // Every thread must own at least one element of the tile.
        constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
        static_assert(0 < ElemPerThread);

        return min(ElemPerThread, MaxVectorSize);
    }

    // Static tile distribution for the kM0 x kSubQKHeaddim Q tile.
    //
    // The M dimension is factored as (MPerThread, NumWarps, MThreadPerWarp) and the K
    // dimension as (KThreads, KPerThread), where KPerThread == GetAlignmentQ<Problem>().
    // All factors are computed with integer division; this function does not itself check
    // that the divisions are exact.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
    {
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;

        // Single source of truth for the vector width (also enforces 0 < ElemPerThread).
        constexpr index_t KPerThread = GetAlignmentQ<Problem>();

        constexpr index_t kWarpSize = get_warp_size();

        constexpr index_t KThreads       = kKPerBlock / KPerThread;    // threads along K
        constexpr index_t MThreadPerWarp = kWarpSize / KThreads;       // rows covered by a warp
        constexpr index_t NumWarps       = kBlockSize / kWarpSize;     // warps in the block
        constexpr index_t MPerThread     = kMPerBlock / (MThreadPerWarp * NumWarps);

        // Encoding arguments, in order: R lengths, H lengths per tile dimension,
        // P -> (R/H major, minor) maps, Y -> (R/H major, minor) maps.
        //   P0 -> NumWarps
        //   P1 -> (MThreadPerWarp, KThreads)
        //   Y  -> (MPerThread, KPerThread)
        using QDistributionEncoding =
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
                                             sequence<KThreads, KPerThread>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>;

        return make_static_tile_distribution(QDistributionEncoding{});
    }

    // Disabled local version of GetQKBlockGemm(), kept for reference only.
    /*
        template <typename Problem>
        CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
        {
            constexpr index_t BlockGemmK = (KLoadOnce && Problem::BlockFmhaShape::kQKHeaddim ==
                                                             Problem::BlockFmhaShape::kSubQKHeaddim)
                                               ? Problem::BlockFmhaShape::kSubQKHeaddim
                                               : Problem::BlockFmhaShape::kK0;

            using GemmProblem = BlockGemmProblem<
                typename Problem::QDataType,
                typename Problem::KDataType,
                typename Problem::SaccDataType,
                Problem::kNumGemm0Warps * get_warp_size(),
                TileGemmShape<
                    sequence<Problem::BlockFmhaShape::kM0,
                             Problem::BlockFmhaShape::kN0,
                             BlockGemmK>,
                    typename Problem::BlockFmhaShape::Gemm0BlockWarps,
                    typename Problem::BlockFmhaShape::Gemm0WarpTile>>;

            constexpr auto warp_gemm = []() {
                constexpr index_t WarpGemmM =
                    Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
                static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);

                if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
                             std::is_same_v<typename Problem::KDataType, half_t> &&
                             std::is_same_v<typename Problem::SaccDataType, float>)
                {
                    if constexpr(WarpGemmM == 32)
                        return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
                    else if constexpr(WarpGemmM == 16)
                        return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
                    else // WarpGemmM == 4
                        return WarpGemmMfmaF16F16F32M4N64K16{};
                }
                else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
                                  std::is_same_v<typename Problem::KDataType, bf16_t> &&
                                  std::is_same_v<typename Problem::SaccDataType, float>)
                {
                    if constexpr(WarpGemmM == 32)
                        return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
                    else if constexpr(WarpGemmM == 16)
                        return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
                    else // WarpGemmM == 4
                        return WarpGemmMfmaBf16Bf16F32M4N64K16{};
                }
                else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
                                  std::is_same_v<typename Problem::KDataType, fp8_t> &&
                                  std::is_same_v<typename Problem::SaccDataType, float>)
                {
                    static_assert(WarpGemmM == 32);

                    // TODO: hard-coded here; otherwise it may produce incorrect results
                    constexpr index_t swizzle_factor = 4;
                    return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
                        swizzle_factor>{};
                } // TODO - bf8_t
            }();

            using BlockGemmPolicy =
                BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType,
                                                     typename Problem::KDataType,
                                                     typename Problem::SaccDataType,
                                                     typename Problem::BlockFmhaShape::Gemm0BlockWarps,
                                                     decltype(warp_gemm)>;

            if constexpr(1 < Problem::kNumGemm0Warps)
                return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
            else
                return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
        }
    */

    // K-pack width for Q in shared memory: the number of QDataType elements (cv/ref
    // qualifiers stripped) that fit in 16 bytes, returned as index_t.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
    {
        // TODO: this is for 3d layout
        using QDataType = remove_cvref_t<typename Problem::QDataType>;
        return static_cast<index_t>(16 / sizeof(QDataType));
    }

    // Tensor descriptor of the Q tile in LDS, exposed as a 2D (kM0, kSubQKHeaddim) view.
    //
    // The underlying storage is 3D with lengths (K / kKPack, M, kKPack) and strides
    // ((M + 1) * kKPack, kKPack, 1): each K-slab holds M rows of kKPack contiguous elements
    // followed by one unused row of padding. The descriptor is then transformed so that
    // dimension 0 is M (pass-through) and dimension 1 is K (merge of the two K factors).
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
    {
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;

        constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
        static_assert(0 < ElemPerThread);

        // Never pack more elements than a single thread owns.
        constexpr index_t kKPack  = min(ElemPerThread, GetSmemKPackQ<Problem>());
        constexpr index_t kKSlabs = kKPerBlock / kKPack;

        // Lengths / strides of the padded 3D storage; the two trailing arguments pass
        // kKPack and 1 as the vector length / stride hints of the last dimension.
        constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<kKSlabs>{}, number<kMPerBlock>{}, number<kKPack>{}),
            make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
            number<kKPack>{},
            number<1>{});

        // (K / kKPack, M, kKPack) -> (M, K)
        constexpr auto q_lds_block_desc = transform_tensor_descriptor(
            q_lds_block_desc_0,
            make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
                       make_merge_transform(make_tuple(number<kKSlabs>{}, number<kKPack>{}))),
            make_tuple(sequence<1>{}, sequence<0, 2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return q_lds_block_desc;
    }

    // Bytes of shared memory needed for Q: element space of the LDS descriptor (padding
    // included) times sizeof(QDataType).
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
    {
        return MakeQLdsBlockDescriptor<Problem>().get_element_space_size() *
               sizeof(typename Problem::QDataType);
    }

    // Total shared-memory budget of the pipeline:
    //     max(Q, K + max(V, Dropout))
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        // assume Q can reuse the shared memory with K or V
        // assume Dropout can reuse the shared memory with V
        return max(GetSmemSizeQ<Problem>(),
                   GetSmemSizeK<Problem>() +
                       max(GetSmemSizeV<Problem>(), GetSmemSizeDropout<Problem>(0)));
    }
};

} // namespace ck_tile
