// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/flatmm_uk.hpp"
#include <string>

// Example-only convenience trait; it is NOT part of the host API.
// Only declared here: the usable definitions are the partial specializations
// that follow, each keyed on an (I, W, O) combination.
// NOTE(review): the parameter roles below are inferred from the matching
// prec_* fields of flatmm_uk_traits -- confirm before relying on them.
template <typename I,  // input precision
          typename W,  // weight precision
          typename O,  // output precision
          typename ST, // token scale type
          typename SW, // weight scale type
          typename SQ, // smooth-quant scale type
          typename KW> // topk-weight type
struct FlatmmUkTypeConfig;

template <typename ST, typename SW, typename SQ, typename KW>
struct FlatmmUkTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>
{
    using ADataType            = ck_tile::bf16_t;
    using GDataType            = ck_tile::bf16_t;
    using DDataType            = ck_tile::bf16_t;
    using AccDataType          = float;
    using ODataType            = ck_tile::bf16_t;
    using AScaleDataType       = ck_tile::remove_cvref_t<ST>;
    using GScaleDataType       = ck_tile::remove_cvref_t<SW>;
    using DScaleDataType       = ck_tile::remove_cvref_t<SW>;
    using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
    using TopkWeightDataType   = ck_tile::remove_cvref_t<KW>;
    using IndexDataType        = ck_tile::index_t;
};

template <typename ST, typename SW, typename SQ, typename KW>
struct FlatmmUkTypeConfig<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, ST, SW, SQ, KW>
{
    using ADataType            = ck_tile::fp16_t;
    using GDataType            = ck_tile::fp16_t;
    using DDataType            = ck_tile::fp16_t;
    using AccDataType          = float;
    using ODataType            = ck_tile::fp16_t;
    using AScaleDataType       = ck_tile::remove_cvref_t<ST>;
    using GScaleDataType       = ck_tile::remove_cvref_t<SW>;
    using DScaleDataType       = ck_tile::remove_cvref_t<SW>;
    using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
    using TopkWeightDataType   = ck_tile::remove_cvref_t<KW>;
    using IndexDataType        = ck_tile::index_t;
};

template <typename ST, typename SW, typename SQ, typename KW>
struct FlatmmUkTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>
{
    using ADataType            = ck_tile::int8_t;
    using GDataType            = ck_tile::int8_t;
    using DDataType            = ck_tile::int8_t;
    using AccDataType          = int32_t;
    using ODataType            = ck_tile::bf16_t;
    using AScaleDataType       = ck_tile::remove_cvref_t<ST>;
    using GScaleDataType       = ck_tile::remove_cvref_t<SW>;
    using DScaleDataType       = ck_tile::remove_cvref_t<SW>;
    using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
    using TopkWeightDataType   = ck_tile::remove_cvref_t<KW>;
    using IndexDataType        = ck_tile::index_t;
};


// Argument bundle passed by value to flatmm_uk().
// All buffers are type-erased raw pointers; the struct has no destructor and
// never frees them. Every member has a default initializer, so a
// default-constructed object holds null pointers and zero sizes instead of
// indeterminate values. Member order and types are unchanged, so aggregate
// initialization in declaration order keeps working.
struct flatmm_uk_args
{
    const void* a_ptr = nullptr; // [m, k], input token
    // NOTE(review): the notes on b_ptr, c_ptr and the dbg_* pointers used to be
    // copies of the a_ptr / d_ptr notes; shapes and roles below are unconfirmed.
    const void* b_ptr  = nullptr; // read-only operand; shape/layout TODO confirm
    const void* c_ptr  = nullptr; // read-only (const) operand, not an output; TODO confirm role
    void* d_ptr        = nullptr; // [m, k], output token (no need to do zeroing)
    void* dbg_int_ptr  = nullptr; // writable; presumably an integer debug buffer -- TODO confirm
    void* dbg_bf16_ptr = nullptr; // writable; presumably a bf16 debug buffer -- TODO confirm
    void* dbg_fp32_ptr = nullptr; // writable; presumably an fp32 debug buffer -- TODO confirm

    ck_tile::index_t block_m           = 0; // block_m, used to divide the input
    ck_tile::index_t hidden_size       = 0; // k
    ck_tile::index_t intermediate_size = 0; // n / TP, for Gate. if Gate+Up, Down need divide by 2
    ck_tile::index_t num_tokens        = 0; // input number of tokens for current iteration
    ck_tile::index_t num_experts       = 0; // number of groups
    ck_tile::index_t topk              = 0; // TODO: need this?

    // for input/output, stride for each row, should >= hidden_size
    ck_tile::index_t stride_token = 0;
};

// This is the public API, will be generated by script
//
// Describes which kernel variant flatmm_uk() is asked for. Precisions are
// carried as strings; the accepted spellings are decided by the dispatcher,
// which is not part of this header.
// The integer members default to 0 so that `flatmm_uk_traits t;` never leaves
// them indeterminate. The type stays an aggregate, so brace-initialization in
// declaration order is unaffected.
struct flatmm_uk_traits
{
    std::string prec_i;  // input precision
    std::string prec_w;  // weight precision
    std::string prec_o;  // output precision
    std::string prec_st; // token scale data type
    std::string prec_sw; // weight scale data type
    std::string prec_sq; // smooth quant scale
    std::string prec_kw; // topk-weight data type
    int block_m     = 0;
    int gate_only   = 0;
    int fused_quant = 0; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};

// Example entry point; declared here and defined elsewhere.
//   traits - kernel variant selection, taken by value
//   args   - buffer pointers and problem sizes, taken by value
//   stream - ck_tile stream configuration, taken by const reference
// NOTE(review): the float result is presumably the measured kernel time as
// reported by the ck_tile launch helpers -- confirm against the definition.
float flatmm_uk(flatmm_uk_traits traits,
                flatmm_uk_args args,
                const ck_tile::stream_config& stream);
