// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"

namespace ck_tile {
// Kernel functor that wraps the reduce buffer and the output buffer in tile windows and
// passes them, together with a block-local scratch buffer, to ReduceReceivePipeline.
//
// Template parameters:
//   CrossReducePartitioner  - supplies GridSize(M, N) on the host; when invoked on the device it
//                             yields the window origin along the M dimension (i_M).
//   ReduceReceivePipeline_  - supplies BlockSize, Block_M, Block_N, Vector_N, DataType,
//                             ODataType, GetSmemSize() and the device-side call operator that
//                             consumes the two windows.
template <typename CrossReducePartitioner, typename ReduceReceivePipeline_>
struct ReduceReceiveKernel
{
    using ReduceReceivePipeline                = remove_cvref_t<ReduceReceivePipeline_>;
    static constexpr index_t TransferBlockSize = ReduceReceivePipeline::BlockSize;
    using DataType  = remove_cvref_t<typename ReduceReceivePipeline::DataType>;
    using ODataType = remove_cvref_t<typename ReduceReceivePipeline::ODataType>;

    // Kernel arguments. Both buffers that operator() touches are addressed as M x N with
    // strides (N, 1).
    struct ReduceReceiveKargs
    {
        // Viewed as `const DataType*` by operator().
        const void* reduce_ptr;
        // NOTE(review): stored here but never read by operator() in this file - confirm whether
        // the pipeline is expected to receive it.
        const void* receive_ptr;
        // Viewed as `const ODataType*` by operator().
        // NOTE(review): const-qualified although it names the output side - confirm how the
        // pipeline stores through the resulting window.
        const void* output_ptr;
        index_t M;
        index_t N;
    };

    // Packs the raw pointers and the problem extents into a ReduceReceiveKargs, in field order.
    CK_TILE_HOST static constexpr ReduceReceiveKargs MakeKargs(const void* reduce_ptr,
                                                               const void* receive_ptr,
                                                               const void* output_ptr,
                                                               index_t M,
                                                               index_t N)
    {
        return ReduceReceiveKargs{reduce_ptr, receive_ptr, output_ptr, M, N};
    }

    // Size, in chars, of the __shared__ scratch array that operator() declares; forwarded
    // unchanged from the pipeline.
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return ReduceReceivePipeline::GetSmemSize();
    }

    // Launch grid for an M_size x N_size problem, as decided by the partitioner.
    // Uses CK_TILE_HOST so that the qualifier matches the other host-side member (MakeKargs).
    CK_TILE_HOST static constexpr auto GridSize(index_t M_size, index_t N_size)
    {
        return CrossReducePartitioner::GridSize(M_size, N_size);
    }

    // Device entry point: builds one Block_M x Block_N window over the reduce buffer and one
    // over the output buffer, both anchored at {i_M, 0}, then runs the pipeline on them.
    CK_TILE_DEVICE void operator()(ReduceReceiveKargs kargs) const
    {
        // Window origin along M for this invocation, chosen by the partitioner.
        const auto i_M = CrossReducePartitioner{}();

        // Both views share the same extents and strides.
        const auto lengths = make_tuple(kargs.M, kargs.N);
        const auto strides = make_tuple(kargs.N, 1);

        // Both windows share the same compile-time tile shape.
        const auto window_lengths = make_tuple(number<ReduceReceivePipeline::Block_M>{},
                                               number<ReduceReceivePipeline::Block_N>{});

        // number<Vector_N> / number<1> are forwarded as the trailing compile-time arguments
        // (presumably the guaranteed vector length / stride of the last dimension - confirm
        // against make_naive_tensor_view).
        const DataType* reduce_start = static_cast<const DataType*>(kargs.reduce_ptr);
        auto transfer_tensor_view    = make_naive_tensor_view<address_space_enum::global>(
            reduce_start,
            lengths,
            strides,
            number<ReduceReceivePipeline::Vector_N>{},
            number<1>{});
        auto transfer_block_window =
            make_tile_window(transfer_tensor_view, window_lengths, {i_M, 0});

        const ODataType* output_start = static_cast<const ODataType*>(kargs.output_ptr);
        auto output_tensor_view       = make_naive_tensor_view<address_space_enum::global>(
            output_start,
            lengths,
            strides,
            number<ReduceReceivePipeline::Vector_N>{},
            number<1>{});
        auto output_block_window = make_tile_window(output_tensor_view, window_lengths, {i_M, 0});

        // Statically sized scratch buffer handed to the pipeline; sized through the kernel's own
        // accessor so the byte count is defined in one place.
        __shared__ char smem_ptr[GetSmemSize()];

        ReduceReceivePipeline{}(transfer_block_window, output_block_window, smem_ptr);
    }
};
} // namespace ck_tile
