#include "src/fastertransformer/kernels/ck_kernels/int8_gemm/int8_gemm.h"
#include "src/fastertransformer/kernels/ck_kernels/kernels.h"
#include "src/fastertransformer/utils/logger.h"

#include "ck_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "ck_extensions/ck_utils.h"
#include "ck/utility/sequence.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "hip/hip_runtime.h"

namespace fastertransformer {
/// Default constructor. The body does nothing beyond emitting a debug trace of the
/// fully-qualified function name.
template<typename T>
CutlassInt8GemmRunner<T>::CutlassInt8GemmRunner()
{
    FT_LOG_DEBUG(__PRETTY_FUNCTION__);
}

/// Destructor. Nothing is released explicitly here; the body only emits a debug
/// trace of the fully-qualified function name.
template<typename T>
CutlassInt8GemmRunner<T>::~CutlassInt8GemmRunner()
{
    FT_LOG_DEBUG(__PRETTY_FUNCTION__);
}

// Short alias for ck::Sequence so the long positional list of Composable Kernel
// tuning arguments (thread-slice lengths, access orders, ...) stays readable.
template<ck::index_t... Vals>
using S = ck::Sequence<Vals...>;

/**
 * @brief Entry point intended to run an int8 x int8 GEMM through Composable Kernel's
 *        DeviceGemmDl device operation.
 *
 * NOTE(review): As written, this body is a stub. It only declares layout / element-op
 * aliases, a GemmSpecialization constant, and the `Gemm` device-operation type. It never
 * creates a `Gemm` object, builds an argument, checks whether the argument is supported,
 * or runs an invoker. Consequently:
 *   - no GPU kernel is launched,
 *   - `C` is never written (callers get whatever `C` already held),
 *   - every parameter below is currently unused.
 * TODO: instantiate `Gemm`, build its argument from A/B/C/m/n/k and strides, verify
 * support, and run it on a stream. This also needs the quant_mode / alpha scaling
 * and the T-typed output handled.
 *
 * @param A               Left int8 operand (presumably a device pointer; layout TODO confirm).
 * @param B               Right int8 operand (presumably a device pointer; layout TODO confirm).
 * @param quant_mode      Quantization mode (unused in the current body).
 * @param alpha_row       Row-wise scale factors, judging by the name (unused in the current body).
 * @param alpha_col       Column-wise scale factors, judging by the name (unused in the current body).
 * @param C               Output buffer of element type T (never written by the current body).
 * @param m               GEMM M dimension (unused in the current body).
 * @param n               GEMM N dimension (unused in the current body).
 * @param k               GEMM K dimension (unused in the current body).
 * @param workspace_ptr   Scratch buffer (unused in the current body).
 * @param workspace_bytes Size of workspace_ptr in bytes (unused in the current body).
 */
template<typename T>
void CutlassInt8GemmRunner<T>::gemm(const int8_t* A,
                                    const int8_t* B,
                                    QuantMode     quant_mode,
                                    const float*  alpha_row,
                                    const float*  alpha_col,
                                    T*            C,
                                    int           m,
                                    int           n,
                                    int           k,
                                    char*         workspace_ptr,
                                    const size_t  workspace_bytes
                                    )
                                    // hipStream_t  stream)
                                    // NOTE(review): the stream parameter is disabled, so
                                    // there is no way to pass a hipStream_t through here.
{
    FT_LOG_DEBUG(__PRETTY_FUNCTION__);
    
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;  // declared but not referenced below

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

    // Device-operation type only; nothing below instantiates or runs it.
    // All arguments are positional. Per DeviceGemmDl's template parameter order (verify
    // against device_gemm_dl.hpp), the leading arguments read as follows:
    //   - data types A/B/C = int8_t, accumulator = int32_t;
    //   - layouts A/B/C all RowMajor;
    //   - PassThrough element-wise ops;
    //   - default GEMM specialization;
    //   - tile/block tuning values (256, 128, 128, 16, 4, ...).
    // The remaining S<...> arguments are block-transfer tuning parameters for A, then B,
    // then the C thread transfer.
    // NOTE(review): the C data type here is int8_t, which does not match the `T* C`
    // output parameter. Reconcile this before wiring up execution.
    using Gemm = ck::tensor_operation::device::DeviceGemmDl< int8_t,  int8_t,  int8_t,    int32_t,     Row,     Row,     Row, PassThrough, PassThrough, PassThrough,    GemmDefault,   256,   128,   128,    16,  4,          4,          4,      1,       S<8, 2>,       S<8, 2>,      S<8, 1, 1, 4>,      S<2, 1, 128, 1>,  S<1, 2, 0, 3>,  S<1, 2, 0, 3>,       S<4, 1, 1, 4>,      S<1, 2, 0, 3>,        S<1, 1, 1, 4>,      S<2, 1, 4, 4>,      S<8, 1,  32, 1>,  S<0, 3, 1, 2>,  S<0, 3, 1, 2>,       S<1, 1, 4, 1>,      S<0, 3, 1, 2>,        S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>,               5,                  4>;

}

/**
 * @brief Upper bound, in bytes, on the workspace needed for an m x n problem.
 *
 * The estimate assumes the smallest tile considered (32 rows by 128 columns). That tile
 * maximises the number of blocks. Each block needs up to 4 bytes, and split_k_limit
 * slices are launched along the z dimension.
 *
 * @param m Number of rows of the output.
 * @param n Number of columns of the output.
 * @param k Reduction dimension; it does not affect the estimate.
 * @return Workspace size in bytes.
 */
template<typename T>
int CutlassInt8GemmRunner<T>::getWorkspaceSize(const int m, const int n, const int k)
{
    FT_LOG_DEBUG(__PRETTY_FUNCTION__);

    constexpr int kMinTileM      = 32;
    constexpr int kMinTileN      = 128;
    constexpr int kBytesPerBlock = 4;

    // Ceiling division gives the block count along each output dimension.
    const int blocks_along_m = (m + kMinTileM - 1) / kMinTileM;
    const int blocks_along_n = (n + kMinTileN - 1) / kMinTileN;

    return blocks_along_m * blocks_along_n * split_k_limit * kBytesPerBlock;
}

}  // namespace fastertransformer