// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "fused_moegemm_api_traits.hpp"
#include "ck_tile/ops/fused_moe.hpp"

// Instantiate and launch the fused-MoE GEMM kernel described by the trait bundle Ts_.
//
// Ts_ is expected to provide (all of these are read below):
//   - GateOnly, FusedQuant                      : compile-time flags forwarded to FusedMoeGemmTraits
//   - BlockTile_0/1, WarpPerBlock_0, WarpTile_0 : tiling types forwarded to FusedMoeGemmShape
//   - ADataType ... IndexDataType               : element types forwarded to the pipeline problem
//
// Parameters:
//   s : stream configuration; s.log_level_ > 0 additionally prints the kernel name to std::cout.
//   a : host-side arguments, used to compute the grid size and to build the kernel arguments.
//
// Returns the value produced by ck_tile::launch_kernel (presumably the measured kernel time --
// TODO confirm against ck_tile::launch_kernel).
template <typename Ts_>
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
{
    using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>;

    // NOTE(review): the second GEMM stage reuses WarpPerBlock_0 / WarpTile_0 rather than *_1
    // counterparts. This is kept exactly as written -- confirm it is intentional.
    using f_shape = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0,
                                               typename Ts_::WarpPerBlock_0,
                                               typename Ts_::WarpTile_0,
                                               typename Ts_::BlockTile_1,
                                               typename Ts_::WarpPerBlock_0,
                                               typename Ts_::WarpTile_0>;

    using f_problem = ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
                                                           typename Ts_::GDataType,
                                                           typename Ts_::DDataType,
                                                           typename Ts_::AccDataType,
                                                           typename Ts_::ODataType,
                                                           typename Ts_::AScaleDataType,
                                                           typename Ts_::GScaleDataType,
                                                           typename Ts_::DScaleDataType,
                                                           typename Ts_::YSmoothScaleDataType,
                                                           typename Ts_::TopkWeightDataType,
                                                           typename Ts_::IndexDataType,
                                                           ck_tile::Gelu, // TODO: hardcoded
                                                           f_shape,
                                                           f_traits>;

    using f_pipeline     = ck_tile::FusedMoeGemmPipeline_Flatmm<f_problem>;
    using f_partitioner  = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
    using f_kernel       = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;

    // The grid depends on the runtime arguments; the block size is a compile-time property of
    // the kernel and is also used as a template argument of make_kernel below.
    const dim3 grids                       = f_kernel::GridSize(a);
    constexpr dim3 blocks                  = f_kernel::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = 1;

    auto kargs = f_kernel::MakeKargs(a);

    // Optional logging: append the kernel name to whatever the caller has already printed.
    if(s.log_level_ > 0)
        std::cout << ", " << f_kernel::GetName() << std::flush;

    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(f_kernel{}, grids, blocks, 0, kargs));
}
