#include "src/turbomind/models/llama/awq_sugon/lmdeploy_sugon.cuh"
#include "src/turbomind/models/llama/awq_sugon/gemm_w4_dequation.cuh"

#include <cstdio>
#include <fstream>
#include <iostream>
#include <string>
#include <type_traits>
#include <vector>


/**
 * Element-wise in-place accumulation: A[i] = A[i] + B[i] for every i in [0, n).
 *
 * Expects a 1-D launch covering at least n threads in total; threads whose
 * global index is >= n do nothing.
 *
 * @param n  number of elements
 * @param A  device-accessible buffer, read and overwritten
 * @param B  device-accessible buffer, read only
 */
template <typename T>
__global__ void add_kernel(int n, T* A, const T* B)
{
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < n) {
        A[idx] = A[idx] + B[idx];
    }
}

/**
 * Element-wise copy: A[i] = B[i] for every i in [0, n).
 *
 * Expects a 1-D launch covering at least n threads in total; threads whose
 * global index is >= n do nothing.
 *
 * @param n  number of elements
 * @param A  device-accessible destination buffer
 * @param B  device-accessible source buffer
 */
template <typename T>
__global__ void assign_kernel(int n, T* A, const T* B)
{
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < n) {
        A[idx] = B[idx];
    }
}


/**
 * Enqueues an element-wise copy A[i] = B[i], i in [0, size), on `stream`
 * using assign_kernel with BLOCKSIZE threads per block.
 * Asynchronous with respect to the host.
 *
 * @param stream  stream the copy kernel is launched on
 * @param A       device-accessible destination, at least `size` elements
 * @param B       device-accessible source, at least `size` elements
 * @param size    number of elements; a non-positive size is a no-op
 */
template <typename T>
void assign_fun(cudaStream_t stream, T* A, const T* B, int size)
{
    // A grid with zero (or negative) blocks is an invalid launch configuration.
    // Skip empty copies instead of leaving a pending launch error behind for an
    // unrelated later error check to report.
    if (size <= 0) {
        return;
    }
    const int num_blocks = (size + BLOCKSIZE - 1) / BLOCKSIZE;
    assign_kernel<<<num_blocks, BLOCKSIZE, 0, stream>>>(size, A, B);
}

// Explicit instantiations of assign_fun. The template is defined only in this
// file, so these lines emit the __half, float, half2 and uint versions for
// other translation units to link against.
#define INSTANTIATEASSIGN(T)  \
template void assign_fun(cudaStream_t stream, T* A,const T* B,int size);

INSTANTIATEASSIGN(__half)
INSTANTIATEASSIGN(float)
INSTANTIATEASSIGN(half2)
INSTANTIATEASSIGN(uint)


/**
 * Debug helper: copies `size` elements of `data` back to the host, optionally
 * dumps them to a binary file, and prints them to stdout.
 *
 * @param stream  stream the device-to-host copy is enqueued on; it is
 *                synchronized before the host reads the data
 * @param data    device-accessible buffer of at least `size` elements
 * @param size    number of elements to copy and print (non-positive = none)
 * @param flag    if non-zero, the raw bytes of the first m*n elements (clamped
 *                to the `size` elements actually copied) are written to
 *                /FrameWork/nvidia_file/elsetest/data<flag>.bin
 * @param m, n    dimensions used only to size the file dump
 *
 * Element values are printed for half, half2, float and uint; other element
 * types print only the banner line and the trailing newline. On a CUDA error
 * the message goes to stderr and the function returns early.
 */
template <typename T>
void PrintScale(cudaStream_t stream, const T* data, int size, int flag, int m, int n)
{
    printf("start printf ****\n");

    const size_t count = size > 0 ? static_cast<size_t>(size) : 0;
    std::vector<T> h_data(count);

    if (count > 0) {
        // Read `data` directly, ordered after the work already queued on `stream`.
        // This replaces the previous scratch cudaMalloc + copy kernel + blocking
        // memcpy round trip.
        cudaError_t err =
            cudaMemcpyAsync(h_data.data(), data, count * sizeof(T), cudaMemcpyDeviceToHost, stream);
        if (err == cudaSuccess) {
            err = cudaStreamSynchronize(stream);
        }
        if (err != cudaSuccess) {
            std::cerr << "PrintScale: device-to-host copy failed: " << cudaGetErrorString(err) << std::endl;
            return;
        }
    }

    if (flag != 0) {
        const std::string file_name = "/FrameWork/nvidia_file/elsetest/data" + std::to_string(flag) + ".bin";
        std::ofstream     outfile(file_name, std::ios::binary);
        if (!outfile) {
            std::cerr << "Failed to open the file for writing." << std::endl;
        }
        else {
            // m*n is caller-supplied and may exceed the number of elements copied.
            // Compute it in size_t (no int overflow) and never read past h_data.
            const size_t requested = (m > 0 && n > 0) ? static_cast<size_t>(m) * static_cast<size_t>(n) : 0;
            const size_t to_write  = requested < count ? requested : count;
            outfile.write(reinterpret_cast<const char*>(h_data.data()), to_write * sizeof(T));
        }
        // outfile is closed by its destructor.
    }

    if constexpr (std::is_same_v<T, half>) {
        for (size_t i = 0; i < count; ++i) {
            printf("%f ", __half2float(h_data[i]));
        }
    }
    else if constexpr (std::is_same_v<T, half2>) {
        for (size_t i = 0; i < count; ++i) {
            printf("x:%f  y:%f ", __half2float(h_data[i].data[0]), __half2float(h_data[i].data[1]));
        }
    }
    else if constexpr (std::is_same_v<T, float>) {
        // float is explicitly instantiated below; without this branch it printed nothing.
        for (size_t i = 0; i < count; ++i) {
            printf("%f ", h_data[i]);
        }
    }
    else if constexpr (std::is_same_v<T, uint>) {
        for (size_t i = 0; i < count; ++i) {
            printf(" %u ", h_data[i]);
        }
    }
    printf("\n");
}


// Explicit instantiations of the PrintScale debug helper for __half, float,
// half2 and uint32_t, so code in other translation units can link against them.
// NOTE(review): PrintScale's element printing tests for `uint`; this relies on
// uint32_t and uint naming the same type on the target platform — confirm.
#define INSTANTIATEPRINT(T)  \
template void PrintScale(cudaStream_t stream,const T* data,int size,int flag,int m,int n);

INSTANTIATEPRINT(__half)
INSTANTIATEPRINT(float)
INSTANTIATEPRINT(half2)
INSTANTIATEPRINT(uint32_t)

/**
 * One thread per output element of a right-zero-padded, row-major matrix.
 *
 * Row width of the output is padded_k = k + count * group_size. Column `col`
 * of row `row` takes input[row * k + col] when col < k, and zero otherwise.
 *
 * @param num_kernels  total output elements; the host wrapper passes
 *                     m * (k + count * group_size). Threads at or beyond it exit.
 * @param output       destination, row-major with row width padded_k
 * @param input        source, row-major with row width k
 * @param m            row count (not read here; implied by num_kernels)
 * @param k            number of columns copied from input
 * @param group_size   width of one padding group
 * @param count        number of zero-filled groups appended to each row
 */
template <typename T>
__global__ void input_padding_kernel(int num_kernels,T* output,const T* input,int m,int k,int group_size,int count)
{
    const int gid = threadIdx.x + blockIdx.x * blockDim.x;
    if (!(gid < num_kernels)) {
        return;
    }

    const int padded_k = k + count * group_size;
    const int row      = gid / padded_k;
    const int col      = gid % padded_k;

    // gid == row * padded_k + col, so the output slot is simply gid.
    if (col >= k) {
        output[gid] = 0.f;  // padding column
        return;
    }
    output[gid] = input[row * k + col];
}


/**
 * Right-pads every row of a row-major matrix with zeros, asynchronously on `stream`.
 *
 * Input shape:  [m, k]
 * Output shape: [m, k + pad_groupcount * group_size]; the trailing
 *               pad_groupcount * group_size columns of each row are zero.
 *
 * @param stream          stream the padding kernel is launched on
 * @param output          device-accessible destination buffer
 * @param input           device-accessible source buffer
 * @param m               number of rows
 * @param k               number of columns in the input
 * @param group_size      width of one padding group
 * @param pad_groupcount  number of padding groups appended to each row
 *
 * A non-positive total output element count is a no-op.
 */
template <typename T>
void input_padding(cudaStream_t stream, T* output,const T* input,int m,int k,int group_size,int pad_groupcount)
{
    const int padded_k    = k + pad_groupcount * group_size;
    const int num_kernels = m * padded_k;
    // Nothing to write; launching would also mean a zero-block grid, which is an
    // invalid launch configuration.
    if (num_kernels <= 0) {
        return;
    }
    input_padding_kernel<<<(num_kernels + BLOCKSIZE - 1) / BLOCKSIZE, BLOCKSIZE, 0, stream>>>(
        num_kernels, output, input, m, k, group_size, pad_groupcount);
}


// Explicit instantiation of input_padding so other translation units can link
// against it. Only __half is instantiated; another element type needs its own
// INSTANTIATEINPUTPADING(...) line.
#define INSTANTIATEINPUTPADING(T)  \
template void input_padding(cudaStream_t stream, T* output,const T* input,int m,int k,int group_size,int pad_groupcount);

INSTANTIATEINPUTPADING(__half)

