// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/layernorm2d.hpp"
#include <string>

// Primary template: maps an (input type, output type, y-scale type) triple to
// the set of element-type aliases used by the layernorm2d forward example.
// Only declared here; the supported input types are provided through partial
// specializations on InType, so any other InType is an incomplete type.
template <typename InType, typename OutType, typename YScaleDataType_>
struct LayerNormTypeConfig;

// Type bundle selected when the input element type is ck_tile::half_t.
// The output and y-scale element types are taken from the remaining
// template arguments; every other alias is fixed by this specialization.
template <typename OutType, typename YScaleDataType_>
struct LayerNormTypeConfig<ck_tile::half_t, OutType, YScaleDataType_>
{
    // Pinned to the half_t input type.
    using XDataType      = ck_tile::half_t;
    using GammaDataType  = ck_tile::half_t;
    using BetaDataType   = ck_tile::half_t;
    using MeanDataType   = ck_tile::half_t;
    using InvStdDataType = ck_tile::half_t;

    // Forwarded unchanged from the template arguments.
    using YDataType      = OutType;
    using YScaleDataType = YScaleDataType_;

    // Fixed to float regardless of the template arguments.
    using ComputeDataType = float;
};

// Type bundle selected when the input element type is ck_tile::bf16_t.
// The output and y-scale element types are taken from the remaining
// template arguments; every other alias is fixed by this specialization.
template <typename OutType, typename YScaleDataType_>
struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, YScaleDataType_>
{
    // Pinned to the bf16_t input type.
    using XDataType      = ck_tile::bf16_t;
    using GammaDataType  = ck_tile::bf16_t;
    using BetaDataType   = ck_tile::bf16_t;
    using MeanDataType   = ck_tile::bf16_t;
    using InvStdDataType = ck_tile::bf16_t;

    // Forwarded unchanged from the template arguments.
    using YDataType      = OutType;
    using YScaleDataType = YScaleDataType_;

    // Fixed to float regardless of the template arguments.
    using ComputeDataType = float;
};

// Runtime arguments for the forward call. Adds no members of its own: it is
// ck_tile::Layernorm2dFwdHostArgs under an example-local name.
struct layernorm2d_fwd_args : ck_tile::Layernorm2dFwdHostArgs
{
};

// Per-configuration entry point, declared here and defined elsewhere for each
// Traits_ instantiation. Takes the stream configuration by const reference and
// the runtime arguments by value.
// NOTE(review): the float result is presumably the measured kernel time from
// the launch helper — confirm against the definition.
template <typename Traits_>
float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a);

// This is the public API, will be generated by script
//
// Runtime selector describing which layernorm2d forward variant to run.
// Every non-string member has a default initializer so that a
// default-initialized object (`layernorm2d_fwd_traits t;`) starts in the
// plain configuration (no mean/var save, no fused add, no fused sweep) instead
// of holding indeterminate values. The struct remains an aggregate, so
// positional brace-initialization keeps working unchanged.
struct layernorm2d_fwd_traits
{
    // Precision names, as strings.
    // NOTE(review): the accepted spellings are defined by the generated
    // dispatcher, which is not visible here — confirm there.
    std::string prec_i; // input precision
    std::string prec_o; // output precision
    std::string prec_s; // scale value, used as scale factor store out when fused_sweep=1

    bool save_mean_var = false;
    int fused_add      = 0; // 0:no-add, 1:pre-add-store, 2:pre-add
    int fused_sweep    = 0; // 0:no-sweep, 1:dynamic-quant
};

// Public entry point: takes the variant selector (traits) and the runtime
// arguments by value, plus the stream configuration by const reference.
// Only declared in this header; the definition lives elsewhere.
// NOTE(review): presumably dispatches to a matching layernorm2d_fwd_<Traits_>
// instantiation and returns its float result (likely a timing value) — confirm
// against the generated definition, including what is returned for an
// unsupported traits combination.
float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&);
