// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "metric.h"
#include "src/turbomind/macro.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <memory>
#include <vector>

namespace turbomind {

// Global flag declared here and defined in another translation unit.
// NOTE(review): by its name this looks like a one-shot switch for dumping kernel
// launch information — confirm against its definition.
extern bool g_dump_kernel_info_once;

// Standalone kernel entry points declared alongside GemmS4F16. Each takes the
// CUDA stream to launch on as its first argument; definitions live elsewhere.
//
// NOTE(review): by name, the dequant_w4_gemm* pair appears to expand packed
// 4-bit `weight` words, together with per-group zero/scale pairs in
// `zeros_and_scales`, into fp16 `output` for a k x n matrix. The `_colmajor`
// variant presumably produces a column-major layout — verify in the .cu file.
void dequant_w4_gemm(cudaStream_t    stream,
                     half*           output,
                     const uint32_t* weight,
                     const half2*    zeros_and_scales,
                     int             k,
                     int             n,
                     int             group_size);

void dequant_w4_gemm_colmajor(cudaStream_t    stream,
                              half*           output,
                              const uint32_t* weight,
                              const half2*    zeros_and_scales,
                              int             k,
                              int             n,
                              int             group_size);

// NOTE(review): presumably applies a SiLU-based activation reading `src` and
// writing `output` for an m x n problem, with `type` selecting the variant —
// confirm against the definition.
void addFusedSiluActivation(cudaStream_t stream, half* output, const half* src, int m, int n, int type);

// Host-side front end for a GEMM whose `A` operand is passed as packed 32-bit
// words and whose `B` / `C` operands are fp16 (`half`).
//
// The implementation is hidden behind a pimpl (`Impl`, defined out of line),
// so this header exposes no kernel details.
//
// NOTE(review): the class name and operand types suggest `A` holds 4-bit
// quantized values packed into 32-bit words, and `Q` carries per-group
// zero/scale pairs (same `const half2*` type as the `zeros_and_scales`
// argument of dequant_w4_gemm) — confirm against the implementation.
//
// The packed operand is spelled `uint32_t` rather than the non-standard `uint`
// typedef (glibc/POSIX only, absent on MSVC). This matches the other
// declarations in this header and is the same `unsigned int` type on CUDA
// host platforms.
class GemmS4F16 {
public:
    GemmS4F16();

    // Declared here and defined where `Impl` is a complete type, as required
    // for destroying a std::unique_ptr<Impl>.
    ~GemmS4F16();

    // Selects the kernel variant used by Measure() / Run().
    enum Type
    {
        kGemm,         // presumably a plain GEMM — TODO confirm
        kFusedSiluFfn  // presumably a GEMM with a fused SiLU/FFN epilogue — TODO confirm
    };

    // Evaluates the available algorithms for one problem configuration.
    //
    // C          : output buffer (fp16).
    // A          : packed operand (see class note; presumably int4 weights).
    // B          : fp16 input operand.
    // Q          : per-group quantization parameters (presumably zeros/scales).
    // m, n, k    : problem dimensions — TODO confirm which operand each binds to.
    // group_size : presumably the quantization group size along k.
    // type       : kernel variant to measure.
    // metrics    : filled/updated by the implementation; presumably one entry
    //              per candidate algorithm, used to choose Run()'s `algo_id`.
    // st         : CUDA stream on which work is issued.
    void Measure(half*                C,
                 const uint32_t*      A,
                 const half*          B,
                 const half2*         Q,
                 int                  m,
                 int                  n,
                 int                  k,
                 int                  group_size,
                 Type                 type,
                 std::vector<Metric>& metrics,
                 cudaStream_t         st);

    // Executes the GEMM variant `type` on stream `st` with the given algorithm.
    // Parameters have the same meaning as in Measure().
    // algo_id : algorithm index. Presumably one of those reported by Measure();
    //           the accepted range is defined by the implementation.
    void Run(half*           C,
             const uint32_t* A,
             const half*     B,
             const half2*    Q,
             int             m,
             int             n,
             int             k,
             int             group_size,
             Type            type,
             int             algo_id,
             cudaStream_t    st);

private:
    struct Impl;
    std::unique_ptr<Impl> impl_;
};

}  // namespace turbomind
