// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdint>

#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"

namespace ck_tile {

/// @brief Batched GEMM kernel built on top of GemmKernel.
///
/// Runs the base GEMM once per batch. For batch index b, the operands start at
///   a_ptr + b * batch_stride_A, b_ptr + b * batch_stride_B, c_ptr + b * batch_stride_C,
/// where each stride is counted in elements of ADataType / BDataType / CDataType
/// respectively (the offsets are applied to typed pointers in operator()).
/// A workgroup takes its batch index from blockIdx.z and its tile coordinates from
/// TilePartitioner. It then hands the batch-adjusted pointers to the inherited RunGemm.
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
{
    using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;

    using GemmKernelArgs = typename Base::GemmKernelArgs;

    using ADataType = typename Base::ADataType;
    using BDataType = typename Base::BDataType;
    using CDataType = typename Base::CDataType;

    using TilePartitioner  = typename Base::TilePartitioner;
    using GemmPipeline     = typename Base::GemmPipeline;
    using EpiloguePipeline = typename Base::EpiloguePipeline;
    using ALayout          = typename Base::ALayout;
    using BLayout          = typename Base::BLayout;
    using CLayout          = typename Base::CLayout;

    /// Kernel arguments: the plain GEMM arguments plus per-batch strides and the batch count.
    struct BatchedGemmKernelArgs : GemmKernelArgs
    {
        index_t batch_stride_A; ///< Element distance between consecutive A matrices.
        index_t batch_stride_B; ///< Element distance between consecutive B matrices.
        index_t batch_stride_C; ///< Element distance between consecutive C matrices.
        index_t batch_count;    ///< Number of batches (not read by operator(); see GridSize).
    };

    using KernelArgs = BatchedGemmKernelArgs;

    /// Launch grid for an M x N problem repeated over batch_count batches.
    /// operator() reads the batch index from blockIdx.z, so this relies on
    /// TilePartitioner::GridSize placing batch_count on the z dimension
    /// (NOTE(review): confirm against the partitioner implementation).
    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_count)
    {
        return TilePartitioner::GridSize(M, N, batch_count);
    }

    /// Workgroup size, taken from the base GEMM kernel.
    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }

    /// Packs the GEMM problem description and batching information into kernel arguments.
    /// @param a_ptr, b_ptr, c_ptr  Base pointers of the first A, B and C matrices.
    /// @param M, N, K              GEMM problem sizes shared by every batch.
    /// @param stride_A, stride_B, stride_C  Leading-dimension strides within one matrix.
    /// @param batch_stride_A, batch_stride_B, batch_stride_C  Element distance between batches.
    /// @param batch_count          Number of batches.
    CK_TILE_HOST static constexpr BatchedGemmKernelArgs MakeKernelArgs(const void* a_ptr,
                                                                       const void* b_ptr,
                                                                       void* c_ptr,
                                                                       index_t M,
                                                                       index_t N,
                                                                       index_t K,
                                                                       index_t stride_A,
                                                                       index_t stride_B,
                                                                       index_t stride_C,
                                                                       index_t batch_stride_A,
                                                                       index_t batch_stride_B,
                                                                       index_t batch_stride_C,
                                                                       index_t batch_count)
    {
        return BatchedGemmKernelArgs{{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C},
                                     batch_stride_A,
                                     batch_stride_B,
                                     batch_stride_C,
                                     batch_count};
    }

    /// Shared-memory requirement: the larger of the main pipeline's and the epilogue's needs.
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    /// Device entry point. The workgroup takes its tile coordinates (i_m, i_n) from
    /// TilePartitioner and its batch index from blockIdx.z. It offsets A/B/C to that
    /// batch and runs the inherited RunGemm on the shifted pointers.
    CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
    {
        const auto [i_m, i_n] = TilePartitioner{}();
        // Broadcast from the first active lane so the batch index is wave-uniform.
        const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z);

        // Batch offsets are formed in 64-bit arithmetic. i_batch * batch_stride can exceed
        // INT32_MAX for large batches even when each stride fits in index_t, and signed
        // 32-bit overflow is undefined behaviour (it typically yields a negative offset).
        // The product is deliberately not passed through readfirstlane: that builtin has
        // historically been int -> int and would truncate it. Both factors are already
        // wave-uniform, so the product is uniform too.
        const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
        const std::int64_t batch_offset_A = static_cast<std::int64_t>(i_batch) * batch_stride_A;
        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A;

        const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
        const std::int64_t batch_offset_B = static_cast<std::int64_t>(i_batch) * batch_stride_B;
        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B;

        const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C);
        const std::int64_t batch_offset_C = static_cast<std::int64_t>(i_batch) * batch_stride_C;
        CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr) + batch_offset_C;

        this->RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
    }
};

} // namespace ck_tile
