#include "src/fastertransformer/kernels/ck_kernels/kernels.h"
#include "src/fastertransformer/kernels/ck_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
#include "src/fastertransformer/kernels/ck_kernels/utils.h"
#include "ck_extensions/ck_utils.h"
#include "src/fastertransformer/utils/logger.h"
#include "hip/hip_runtime.h"

namespace fastertransformer {

// Construct the runner and cache properties of the currently active HIP device:
// the SM version (via getSMVersion()) and the device's multiprocessor count.
template<typename T, typename WeightType>
CutlassFpAIntBGemmRunner<T, WeightType>::CutlassFpAIntBGemmRunner()
{
    FT_LOG_DEBUG(__PRETTY_FUNCTION__);

    // Resolve which device this thread is bound to; the attribute query below needs its id.
    int active_device = -1;
    check_hip_error(hipGetDevice(&active_device));

    sm_ = getSMVersion();

    check_hip_error(
        hipDeviceGetAttribute(&multi_processor_count_, hipDeviceAttributeMultiprocessorCount, active_device));
}


// Destructor: nothing is released here explicitly; the call is only traced at debug level.
template<typename T, typename WeightType>
CutlassFpAIntBGemmRunner<T, WeightType>::~CutlassFpAIntBGemmRunner() { FT_LOG_DEBUG(__PRETTY_FUNCTION__); }


// Mixed-precision GEMM (activations of type T, weights of type WeightType scaled by
// `weight_scales`) with a bias + activation epilogue, writing the m x n result to `C`.
//
// NOTE(review): no ROCm/CK kernel is wired up for this entry point yet. The previous body
// was empty, so every call "succeeded" while leaving `C` untouched, and callers silently
// consumed garbage. Until a kernel is implemented, fail loudly instead, the same way the
// float specialization in this file reports unsupported paths.
// A trailing `hipStream_t stream` parameter is currently disabled (commented out below).
template<typename T, typename WeightType>
void CutlassFpAIntBGemmRunner<T, WeightType>::gemm_bias_act(const T*          A,
                                                            const WeightType* B,
                                                            const T*          weight_scales,
                                                            const T*          biases,
                                                            T*                C,
                                                            int               m,
                                                            int               n,
                                                            int               k,
                                                            ActivationType    activation_type,
                                                            char*             workspace_ptr,
                                                            const size_t      workspace_bytes
                                                            )
                                                            // hipStream_t      stream)
{
    FT_LOG_DEBUG(__PRETTY_FUNCTION__);
    FT_CHECK_WITH_INFO(false, "CutlassFpAIntBGemmRunner::gemm_bias_act is not implemented for ROCm yet.");
}

// Mixed-precision GEMM without bias/activation: activations of type T times weights of type
// WeightType scaled by `weight_scales`, writing the m x n result to `C`.
//
// NOTE(review): no ROCm/CK kernel is wired up for this entry point yet. The previous body
// was empty, so every call "succeeded" while leaving `C` untouched. Fail loudly until a
// kernel is implemented, matching the float specialization's error reporting in this file.
// A trailing `hipStream_t stream` parameter is currently disabled (commented out below).
template<typename T, typename WeightType>
void CutlassFpAIntBGemmRunner<T, WeightType>::gemm(const T*          A,
                                                   const WeightType* B,
                                                   const T*          weight_scales,
                                                   T*                C,
                                                   int               m,
                                                   int               n,
                                                   int               k,
                                                   char*             workspace_ptr,
                                                   const size_t      workspace_bytes
                                                   )
                                                //    hipStream_t      stream)
{
    FT_LOG_DEBUG(__PRETTY_FUNCTION__);
    FT_CHECK_WITH_INFO(false, "CutlassFpAIntBGemmRunner::gemm is not implemented for ROCm yet.");
}


// Worst-case workspace size in bytes for an m x n x k problem.
// Sized for the smallest tile shape (32 x 128), which produces the most thread blocks, with
// 4 bytes per block for each of the `split_k_limit` slices launched along grid z.
// `k` does not affect the result.
// NOTE(review): the product is computed in `int`; extremely large m * n could overflow.
template<typename T, typename WeightType>
int CutlassFpAIntBGemmRunner<T, WeightType>::getWorkspaceSize(const int m, const int n, const int k)
{
    FT_LOG_DEBUG(__PRETTY_FUNCTION__);

    constexpr int kMinTileM      = 32;
    constexpr int kMinTileN      = 128;
    constexpr int kBytesPerBlock = 4;

    const auto ceil_div = [](int value, int divisor) { return (value + divisor - 1) / divisor; };

    const int blocks_in_grid = ceil_div(m, kMinTileM) * ceil_div(n, kMinTileN);
    return blocks_in_grid * split_k_limit * kBytesPerBlock;
}

// ============ Specialization T == float (reached when T == WeightType; mixed GEMM unsupported) ============
// float-activation specialization: there is no mixed-precision path here. The runtime message
// treats it as the "types are the same" case, so every call is reported through
// FT_CHECK_WITH_INFO. A trailing `hipStream_t stream` parameter is currently disabled.
template<typename WeightType>
void CutlassFpAIntBGemmRunner<float, WeightType>::gemm_bias_act(const float*      A,
                                                                const WeightType* B,
                                                                const float*      weight_scales,
                                                                const float*      biases,
                                                                float*            C,
                                                                int               m,
                                                                int               n,
                                                                int               k,
                                                                ActivationType    activation_type,
                                                                char*             workspace_ptr,
                                                                const size_t      workspace_bytes)
{
    FT_LOG_DEBUG(__PRETTY_FUNCTION__);
    FT_CHECK_WITH_INFO(false, "Attempting to run mixed gemm bias act when the types are the same is an error.");
}

// float-activation specialization of the plain GEMM: unsupported. Any call is reported as an
// error via FT_CHECK_WITH_INFO. A trailing `hipStream_t stream` parameter is currently disabled.
template<typename WeightType>
void CutlassFpAIntBGemmRunner<float, WeightType>::gemm(const float*      A,
                                                       const WeightType* B,
                                                       const float*      weight_scales,
                                                       float*            C,
                                                       int               m,
                                                       int               n,
                                                       int               k,
                                                       char*             workspace_ptr,
                                                       const size_t      workspace_bytes)
{
    FT_LOG_DEBUG(__PRETTY_FUNCTION__);
    FT_CHECK_WITH_INFO(false, "Attempting to run mixed gemm when the types are the same is an error.");
}

// float-activation specialization: its GEMM entry points only report an error and never
// launch work, so no scratch space is required regardless of m, n, k.
template<typename WeightType>
int CutlassFpAIntBGemmRunner<float, WeightType>::getWorkspaceSize(const int m, const int n, const int k)
{
    FT_LOG_DEBUG(__PRETTY_FUNCTION__);
    constexpr int kNoWorkspaceBytes = 0;
    return kNoWorkspaceBytes;
}

}  // namespace fastertransformer