// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "metric.h"
#include "src/turbomind/macro.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <memory>
#include <vector>
// Forward declaration of HIP's opaque stream handle type. This is presumably here
// so this header can name hipStream_t without including the HIP runtime headers.
// NOTE(review): this must stay identical to HIP's own definition of hipStream_t;
// a differing typedef is a compile error in any TU that also includes HIP headers.
typedef struct ihipStream_t* hipStream_t;
// Weight-only GEMM entry point. It is declared at global scope (outside namespace
// turbomind) and defined in another translation unit that is not visible here.
// NOTE(review): it is declared without extern "C", so the definition must be a
// C++ function with exactly these parameter types (name mangling depends on them).
//
// Parameters (meanings inferred from names, TODO confirm against the definition):
//   A                 - input matrix pointer (type-erased)
//   B0, B1            - two weight-side inputs; presumably packed low-bit
//                       weights plus their scales/zeros
//   C                 - output matrix pointer
//   M, N, K           - GEMM problem dimensions
//   StrideA/B/C       - leading-dimension strides of A, B and C
//   StrideB_padded    - a padded stride variant for B
//   Group             - presumably the quantization group size
//   splitK_padA_workspace              - caller-provided scratch buffer
//   splitK_padA_workspace_elementSize  - size associated with that buffer
//                                        (units not visible here)
//   stream_id         - stream to run on; defaults to 0 (a null stream handle)
extern void run_weight_only_gemm(const void *A,
                        const void *B0,
                        const void *B1,
                        void *C,
                        int M,
                        int N,
                        int K,
                        int StrideA,
                        int StrideB,
                        int StrideB_padded,     
                        int StrideC,
                        int Group,
                        void* splitK_padA_workspace,                  
                        int splitK_padA_workspace_elementSize,        
                        hipStream_t stream_id=0);

namespace turbomind {

extern bool g_dump_kernel_info_once;
void dequant_w4_gemm(cudaStream_t stream, half* output,const uint32_t* weight,const half2* zeros_and_scales,int k, int n, int group_size);

void addFusedSiluActivation(cudaStream_t stream,half* output, const half* src,int m,int n,int type);
void dequant_w4_gemm_colmajor(cudaStream_t stream, half* output,const uint32_t* weight,const half2* zeros_and_scales,int k, int n, int group_size);
template <typename T>
void input_padding(cudaStream_t stream, T* output,const T* input,int m,int k,int group_size,int pad_groupcount);

class GemmS4F16 {
public:
    GemmS4F16();

    ~GemmS4F16();

    enum Type
    {
        kGemm,
        kFusedSiluFfn
    };

    void Measure(half*                C,
                 const uint*          A,
                 const half*          B,
                 const half2*         Q,
                 int                  m,
                 int                  n,
                 int                  k,
                 int                  group_size,
                 Type                 type,
                 std::vector<Metric>& metrics,
                 cudaStream_t         st);

    void Run(half*        C,
             const uint*  A,
             const half*  B,
             const half2* Q,
             int          m,
             int          n,
             int          k,
             int          group_size,
             Type         type,
             int          algo_id,
             cudaStream_t st);

private:
    struct Impl;
    std::unique_ptr<Impl> impl_;
};

}  // namespace turbomind
