#include "hip/hip_runtime.h"
/*******************************************************************************************
 * This file contains the specific implementation of general helper functions.
 ******************************************************************************************/

#include "common.h"

// Host-side pitch values for the pitched device arrays.
// NOTE(review): presumably written by the allocation code elsewhere in the
// project; nothing in this file assigns them — confirm against the allocator.
size_t pitch_J;
size_t pitch_sp_num;
size_t pitch_react_num;
size_t pitch_n_vars;
size_t pitch_n_vars2;
size_t pitch_seulex_tmp_ptr;
size_t pitch_table;

// Constant-memory counterparts of the pitch values above (suffix "_d").
// NOTE(review): presumably uploaded with hipMemcpyToSymbol elsewhere — confirm.
__device__ __constant__ size_t pitch_J_d;
__device__ __constant__ size_t pitch_sp_num_d;
__device__ __constant__ size_t pitch_react_num_d;
__device__ __constant__ size_t pitch_n_vars_d;
__device__ __constant__ size_t pitch_n_vars2_d;
__device__ __constant__ size_t pitch_seulex_tmp_ptr_d;
__device__ __constant__ size_t pitch_table_d;

// Problem dimensions in constant memory. size_d bounds the per-cell thread
// index and sp_num_d bounds the species loop in the kernels of this file.
__device__ __constant__ size_t size_d;
__device__ __constant__ size_t sp_num_d;
__device__ __constant__ size_t react_num_d;

// Constant tables sized from k_max (declared outside this file).
// NOTE(review): names suggest the step sequence and extrapolation
// coefficients of the SEULEX integrator — confirm.
__device__ __constant__ size_t n_seq_d[k_max+1];
__device__ __constant__ REAL   coeff_d[(k_max+1)*(k_max+1)];


// Host-side time threshold; cuda_copy() forwards it to cuda_copy_g, which
// only copies entries whose dt_sum_d value is <= this threshold.
REAL t_end_h = 1000000.;

// Device array read per entry by cuda_copy_g (compared against t_end_h).
REAL *dt_sum_d = nullptr;

// NOTE(review): presumably the device buffers passed as real_index_ref /
// real_num_ref / real_num_total_ref to get_id_index() and reconstructY() —
// confirm at the call sites.
int *real_index = nullptr;

int *real_num = nullptr;

int *real_num_total = nullptr;

// Streams and events shared by the project; every kernel launch in this file
// is issued on Stream_opencc[0].
hipStream_t Stream_opencc[Stream_num];
hipEvent_t  Event[Stream_num];

/**
 * Allocate `size` bytes of host memory through the HIP host allocator and
 * store the resulting pointer in *p.
 *
 * On failure a diagnostic naming the requesting function, file and line is
 * printed through MPI_PRINTF and the process exits with EXIT_FAILURE.
 *
 * @param p        receives the allocated pointer
 * @param size     number of bytes requested
 * @param funname  name of the requesting function (diagnostic only)
 * @param file     source file of the request (diagnostic only)
 * @param line     source line of the request (diagnostic only)
 */
void malloc_Host_(void **p, int size, const char *funname, const char *file, int line)
{
    hipError_t Status;

    // The allocator entry point differs depending on the compiler in use.
#ifdef __HIPCC__
    Status = hipHostMalloc(p, size, hipHostMallocDefault);
#else
    Status = hipHostAlloc(p, size, hipHostMallocDefault);
#endif

    if (Status == hipSuccess) {
        return;
    }

    MPI_PRINTF("Memory allocate error ! Can not allocate enough momory in fun %s ( file %s, line %d )\n", funname, file, line);
    exit(EXIT_FAILURE);
}

/**
 * Fill kernel: stores `value` into element (i, j) of a pitched array, where
 * the element offset is i + j*pitch.
 *
 * Launch layout: the x dimension of the grid indexes j (0..height-1) and the
 * y dimension indexes i (0..width-1). Threads outside that range do nothing.
 *
 * @param value   value to store
 * @param ptr     base pointer of the destination array
 * @param width   number of valid i indices
 * @param pitch   stride, in elements, between consecutive j indices
 * @param height  number of valid j indices
 */
__global__ void cuda_mem_value_init_g(REAL value, REAL *ptr, unsigned int width, unsigned int pitch, unsigned int height)
{
    const unsigned int outer = blockIdx.x * blockDim.x + threadIdx.x;   // j
    const unsigned int inner = blockIdx.y * blockDim.y + threadIdx.y;   // i

    // Threads past the array edge have nothing to write.
    if (inner >= width || outer >= height) {
        return;
    }

    REAL *dst = ptr + inner + outer*pitch;
    *dst = value;
}

/**
 * Host wrapper: asynchronously fills a pitched device array with `value` by
 * launching cuda_mem_value_init_g on Stream_opencc[0].
 *
 * The launch geometry is obtained from set_block_grid2d(height, width, ...).
 * No synchronisation is performed here.
 */
void cuda_mem_value_init(REAL value, REAL *ptr, unsigned int width, unsigned int pitch, unsigned int height, dim3 blockset) {

    dim3 grid_cfg;
    dim3 block_cfg;
    set_block_grid2d(height, width, blockset, grid_cfg, block_cfg);

    hipStream_t stream = Stream_opencc[0];
    cuda_mem_value_init_g<<<grid_cfg, block_cfg, 0, stream>>>(value, ptr, width, pitch, height);
}

/**
 * Conditional copy kernel: copies element (i, j) from `src` to `ptr` (same
 * offset i + j*pitch in both arrays), but only for those j whose entry in
 * dt_sum_d is <= t_end_h.
 *
 * Launch layout: grid x indexes j (0..height-1), grid y indexes i
 * (0..width-1). Out-of-range threads do nothing.
 *
 * @param src       source array
 * @param ptr       destination array
 * @param width     number of valid i indices
 * @param pitch     stride, in elements, between consecutive j indices
 * @param height    number of valid j indices
 * @param dt_sum_d  per-j value compared against t_end_h
 * @param t_end_h   threshold; entries with a larger dt_sum_d are skipped
 */
__global__ void cuda_copy_g(REAL *src, REAL *ptr, unsigned int width, unsigned int pitch, unsigned int height, REAL *dt_sum_d, REAL t_end_h)
{
    const unsigned int outer = blockIdx.x * blockDim.x + threadIdx.x;   // j
    const unsigned int inner = blockIdx.y * blockDim.y + threadIdx.y;   // i

    // Range check first so dt_sum_d is only read for valid j.
    if (inner >= width || outer >= height) {
        return;
    }

    // Written as a negated "<=" so the comparison outcome is unchanged for
    // every value, including NaN.
    if (!(access_data(dt_sum_d, outer) <= t_end_h)) {
        return;
    }

    const REAL *from = src + inner + outer*pitch;
    REAL       *to   = ptr + inner + outer*pitch;
    *to = *from;
}

/**
 * Host wrapper: launches cuda_copy_g on Stream_opencc[0], passing the
 * file-scope dt_sum_d array and t_end_h threshold as the copy condition.
 *
 * The launch geometry is obtained from set_block_grid2d(height, width, ...).
 * No synchronisation is performed here.
 */
void cuda_copy(REAL *src, REAL *ptr, unsigned int width, unsigned int pitch, unsigned int height, dim3 blockset) {

    dim3 grid_cfg;
    dim3 block_cfg;
    set_block_grid2d(height, width, blockset, grid_cfg, block_cfg);

    hipStream_t stream = Stream_opencc[0];
    cuda_copy_g<<<grid_cfg, block_cfg, 0, stream>>>(src, ptr, width, pitch, height, dt_sum_d, t_end_h);
}

#ifdef __NVCC__
/**
 * Atomic maximum for double, built on a 64-bit compare-and-swap.
 *
 * Stores max(*addr, value) into *addr and returns the value that was
 * observed at *addr (the previous content when a store happened, otherwise
 * the current content that was already >= value).
 *
 * The loop works on the raw 64-bit pattern:
 *  - the result of atomicCAS is a bit pattern and must be reinterpreted with
 *    __longlong_as_double, never converted numerically;
 *  - the ">= value" test is repeated after every failed CAS, so a larger
 *    value written concurrently by another thread is never overwritten by a
 *    smaller one;
 *  - the retry condition compares integers, so it terminates even when the
 *    stored value is NaN.
 */
static inline __device__ double atomicMax(double *addr, double value) {
  unsigned long long int *addr_bits = (unsigned long long int *)addr;
  unsigned long long int old_bits = *addr_bits;
  unsigned long long int assumed_bits;

  do {
    assumed_bits = old_bits;

    // Nothing to do if the stored value is already large enough.
    if (__longlong_as_double((long long int)assumed_bits) >= value) break;

    old_bits = atomicCAS(addr_bits, assumed_bits,
                         (unsigned long long int)__double_as_longlong(value));

  } while (old_bits != assumed_bits);

  return __longlong_as_double((long long int)old_bits);
}

/**
 * Atomic minimum for double, built on a 64-bit compare-and-swap.
 *
 * Stores min(*addr, value) into *addr and returns the value that was
 * observed at *addr (the previous content when a store happened, otherwise
 * the current content that was already <= value).
 *
 * The loop works on the raw 64-bit pattern:
 *  - the result of atomicCAS is a bit pattern and must be reinterpreted with
 *    __longlong_as_double, never converted numerically;
 *  - the "<= value" test is repeated after every failed CAS, so a smaller
 *    value written concurrently by another thread is never overwritten by a
 *    larger one;
 *  - the retry condition compares integers, so it terminates even when the
 *    stored value is NaN.
 */
static inline __device__ double atomicMin(double *addr, double value) {
  unsigned long long int *addr_bits = (unsigned long long int *)addr;
  unsigned long long int old_bits = *addr_bits;
  unsigned long long int assumed_bits;

  do {
    assumed_bits = old_bits;

    // Nothing to do if the stored value is already small enough.
    if (__longlong_as_double((long long int)assumed_bits) <= value) break;

    old_bits = atomicCAS(addr_bits, assumed_bits,
                         (unsigned long long int)__double_as_longlong(value));

  } while (old_bits != assumed_bits);

  return __longlong_as_double((long long int)old_bits);
}
#endif

/**
 * Resets the max-reduction accumulator: every thread whose global x index is
 * below `height` stores 0 into *ptr.
 *
 * NOTE(review): starting from 0 means the reduction result can never be
 * negative — presumably the reduced data is non-negative; confirm.
 *
 * @param ptr     accumulator to reset
 * @param height  number of threads allowed to perform the store
 */
__global__ void cuda_max_init_g(REAL *ptr, unsigned int height)
{
    const unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;

    if (tid >= height) {
        return;
    }

    *ptr = 0.;
}

/**
 * Max-reduction kernel: every in-range thread folds its element
 * src[i + j*pitch] into the single accumulator *ptr with atomicMax.
 *
 * Launch layout: grid x indexes j (0..height-1), grid y indexes i
 * (0..width-1). The accumulator must have been initialised beforehand
 * (cuda_max() does so with cuda_max_init_g).
 */
__global__ void cuda_max_g(REAL *src, REAL *ptr, unsigned int width, unsigned int pitch, unsigned int height)
{
    const unsigned int outer = blockIdx.x * blockDim.x + threadIdx.x;   // j
    const unsigned int inner = blockIdx.y * blockDim.y + threadIdx.y;   // i

    if (inner >= width || outer >= height) {
        return;
    }

    const REAL *elem = src + inner + outer*pitch;
    atomicMax(ptr, *elem);
}

/**
 * Host wrapper: computes the maximum over a pitched device array into *ptr.
 *
 * Two launches are queued on Stream_opencc[0]: a single-thread reset of the
 * accumulator, then the atomic max reduction over the whole array. No
 * synchronisation is performed here.
 */
void cuda_max(REAL *src, REAL *ptr, unsigned int width, unsigned int pitch, unsigned int height, dim3 blockset) {

    dim3 grid_cfg;
    dim3 block_cfg;
    set_block_grid2d(height, width, blockset, grid_cfg, block_cfg);

    hipStream_t stream = Stream_opencc[0];

    // Stream order guarantees the reset finishes before the reduction starts.
    cuda_max_init_g<<<1, 1, 0, stream>>>(ptr, 1);
    cuda_max_g<<<grid_cfg, block_cfg, 0, stream>>>(src, ptr, width, pitch, height);
}

/**
 * Resets the min-reduction accumulator: every thread whose global x index is
 * below `height` stores 1000000 into *ptr.
 *
 * NOTE(review): the reduction result is therefore capped at 1000000 —
 * presumably the reduced data never exceeds that; confirm.
 *
 * @param ptr     accumulator to reset
 * @param height  number of threads allowed to perform the store
 */
__global__ void cuda_min_init_g(REAL *ptr, unsigned int height)
{
    const unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;

    if (tid >= height) {
        return;
    }

    *ptr = 1000000;
}

/**
 * Min-reduction kernel: every in-range thread folds its element
 * src[i + j*pitch] into the single accumulator *ptr with atomicMin.
 *
 * Launch layout: grid x indexes j (0..height-1), grid y indexes i
 * (0..width-1). The accumulator must have been initialised beforehand
 * (cuda_min() does so with cuda_min_init_g).
 */
__global__ void cuda_min_g(REAL *src, REAL *ptr, unsigned int width, unsigned int pitch, unsigned int height)
{
    const unsigned int outer = blockIdx.x * blockDim.x + threadIdx.x;   // j
    const unsigned int inner = blockIdx.y * blockDim.y + threadIdx.y;   // i

    if (inner >= width || outer >= height) {
        return;
    }

    const REAL *elem = src + inner + outer*pitch;
    atomicMin(ptr, *elem);
}

/**
 * Host wrapper: computes the minimum over a pitched device array into *ptr.
 *
 * Two launches are queued on Stream_opencc[0]: a single-thread reset of the
 * accumulator, then the atomic min reduction over the whole array. No
 * synchronisation is performed here.
 */
void cuda_min(REAL *src, REAL *ptr, unsigned int width, unsigned int pitch, unsigned int height, dim3 blockset) {

    dim3 grid_cfg;
    dim3 block_cfg;
    set_block_grid2d(height, width, blockset, grid_cfg, block_cfg);

    hipStream_t stream = Stream_opencc[0];

    // Stream order guarantees the reset finishes before the reduction starts.
    cuda_min_init_g<<<1, 1, 0, stream>>>(ptr, 1);
    cuda_min_g<<<grid_cfg, block_cfg, 0, stream>>>(src, ptr, width, pitch, height);
}

/**
 * Butterfly (xor-shuffle) sum across one warp/wavefront.
 *
 * Every lane contributes `mySum`; on return every lane holds the sum over
 * all lanes of its warp. All lanes of the warp must call this together.
 *
 * The exchange distances are derived from warpSize (warpSize/2, ..., 2, 1)
 * instead of starting at a fixed 32, so the reduction is correct on both
 * 64-lane and 32-lane hardware. For warpSize == 64 the sequence of additions
 * is exactly 32, 16, 8, 4, 2, 1.
 *
 * @param mySum  this lane's contribution
 * @return       sum of the contributions of all lanes in the warp
 */
__device__ REAL warpReduce(REAL mySum){
    for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
        mySum += __shfl_xor_double(mySum, offset, warpSize);
    }
    return mySum;
}

/**
 * First stage of the sum reduction: each block adds up its slice of `p`
 * (entries with global index < size_d) and writes one partial sum to
 * g_odata[blockIdx.x].
 *
 * Requirements visible from the code:
 *  - 1D launch;
 *  - dynamic shared memory holding at least one REAL per warp of the block;
 *  - SMEMDIM is the number of per-warp partials to combine (sum() passes
 *    blockDim.x / warpSize).
 *
 * @param p        input values
 * @param SMEMDIM  number of per-warp partial sums stored in shared memory
 * @param g_odata  output, one partial sum per block
 */
__global__ void add0_kernel(REAL *p, int SMEMDIM, REAL *g_odata){
    extern __shared__ REAL shared[];

    const unsigned int tid  = threadIdx.x;
    const unsigned int cell = blockDim.x * blockIdx.x + tid;
    const unsigned int lane = tid % warpSize;
    const unsigned int warp = tid / warpSize;

    // Threads past the end of the data contribute zero.
    REAL partial = 0.;
    if (cell < size_d) {
        partial = access_data(p, cell);
    }

    // Stage 1: reduce inside each warp; lane 0 publishes the warp's sum.
    partial = warpReduce(partial);
    if (lane == 0) {
        shared[warp] = partial;
    }
    __syncthreads();

    // Stage 2: the first SMEMDIM threads pick up one warp sum each and the
    // first warp combines them.
    if (tid < SMEMDIM) {
        partial = shared[tid];
    } else {
        partial = 0;
    }

    if (warp == 0) {
        partial = warpReduce(partial);
    }

    if (tid == 0) {
        g_odata[blockIdx.x] = partial;
    }
}

/**
 * In-place follow-up stage of the sum reduction.
 *
 * On return g_odata[0 .. gridDim.x) holds one partial sum per block, every
 * other entry below g_odata_size is zero, and the total over the array is
 * unchanged. Launching repeatedly with a shrinking grid therefore collapses
 * the array into g_odata[0].
 *
 * Requirements visible from the code:
 *  - 1D launch with at most 8 warps per block;
 *  - 8 * sizeof(REAL) bytes of dynamic shared memory.
 *
 * Element ownership: element e is read (and, if e >= gridDim.x, zeroed) only
 * by block e % gridDim.x. Block b is thus the only block that ever touches
 * g_odata[b], so writing the block result there cannot race with another
 * block still reading its input. Summation order inside a block differs
 * from a contiguous layout, so results may differ at round-off level.
 *
 * @param g_odata       partial sums, reduced in place
 * @param g_odata_size  number of valid entries in g_odata
 */
__global__ void add1_kernel(REAL *g_odata, int g_odata_size){
    extern __shared__ REAL shared[];

    const unsigned int tid     = threadIdx.x;
    const unsigned int warpId  = tid / warpSize;
    const unsigned int laneIdx = tid % warpSize;

    const size_t count  = (g_odata_size > 0) ? (size_t)g_odata_size : 0;
    const size_t stride = (size_t)gridDim.x * blockDim.x;

    // Clear all 8 slots so unused ones contribute zero. The barrier keeps
    // this store from landing after another warp's partial sum below.
    if(tid < 8) shared[tid] = 0.;
    __syncthreads();

    REAL grad_f0 = 0.;
    for(size_t e = blockIdx.x + (size_t)gridDim.x * tid; e < count; e += stride){
        grad_f0 += g_odata[e];
        // Consumed tail entries are cleared so a later pass reads zeros;
        // the "e < count" loop bound keeps this store inside the array.
        if(e >= gridDim.x) g_odata[e] = 0.0;
    }

    grad_f0 = warpReduce(grad_f0);
    if(laneIdx == 0) shared[warpId] = grad_f0;
    __syncthreads();

    grad_f0 = (tid < 8)?shared[tid]:0;

    if(warpId == 0) grad_f0 = warpReduce(grad_f0);

    if(tid == 0) g_odata[blockIdx.x] = grad_f0;
}

/**
 * Computes the mean of a device array: sums the entries of P_d on the GPU,
 * optionally adds the contributions of all MPI ranks, and divides by
 * size_all.
 *
 * The reduction runs on Stream_opencc[0]: add0_kernel produces one partial
 * sum per block, then add1_kernel is applied repeatedly until a single value
 * remains in g_odata[0]. The call blocks until the result is on the host.
 *
 * NOTE(review): the MPI path reduces with MPI_DOUBLE, which assumes REAL is
 * double — confirm.
 *
 * @param size      number of local entries (used for the launch geometry)
 * @param size_all  divisor applied to the (global) sum
 * @param P_d       device array to reduce
 * @param result_h  host location receiving the mean
 */
void sum(size_t size, size_t size_all, REAL *P_d, REAL *result_h){

    dim3 griddim, blockdim;

    set_block_grid(size, 256, griddim, blockdim);

    REAL *g_odata = nullptr;
    // Plain local instead of a heap buffer: the previous malloc'ed scalar was
    // never freed and leaked on every call.
    REAL Sum = 0.;

    unsigned int g_odata_size = griddim.x;
    CUDACHECK(( hipMalloc((REAL **)&g_odata, g_odata_size*sizeof(REAL)) ));

    // One shared-memory slot per warp of the first-stage block.
    int SMEMDIM = blockdim.x/warpSize;
    add0_kernel<<<griddim, blockdim, SMEMDIM*sizeof(REAL), Stream_opencc[0]>>>(P_d, SMEMDIM, g_odata);

    dim3 blockdim_sum(256);
    dim3 griddim_sum(g_odata_size);

    // Each pass shrinks the number of live partial sums by a factor of 256.
    do{
        griddim_sum.x = (griddim_sum.x + blockdim_sum.x - 1)/blockdim_sum.x;
        add1_kernel<<<griddim_sum, blockdim_sum, 8*sizeof(REAL), Stream_opencc[0]>>>(g_odata, g_odata_size);
    } while(griddim_sum.x > 1);

    // The copy below is issued on the default stream; wait for the reduction
    // stream explicitly so the result is complete regardless of how
    // Stream_opencc[0] was created.
    CUDACHECK(( hipStreamSynchronize(Stream_opencc[0]) ));
    CUDACHECK(( hipMemcpy(&Sum, g_odata, sizeof(REAL), hipMemcpyDeviceToHost) ));
    CUDACHECK(( hipFree(g_odata) ));

#ifdef __USE_MPI__
    MPI_Allreduce(&Sum, result_h, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
#else
    *result_h = Sum;
#endif //__USE_MPI__

    *result_h /= size_all;
}

/**
 * Per-block compaction of the cells whose T value is >= Tsd.
 *
 * For block b, the global indices of the selected cells are written, in
 * thread order, to real_index_ref[b*blockDim.x + 0 .. count_b), and count_b
 * is folded into real_num_ref[b] with atomicMax.
 *
 * Requirements visible from the code:
 *  - 1D launch with blockDim.x * sizeof(int) bytes of dynamic shared memory;
 *  - real_num_ref[b] is only ever raised here, so it must hold a value
 *    <= count_b (e.g. 0) before the launch.
 *
 * The block-wide barrier sits outside the selection branch: every thread
 * publishes its flag, all threads reach __syncthreads(), and only then do
 * the selected threads read their neighbours' flags.
 *
 * @param T               per-cell values tested against Tsd
 * @param real_index_ref  output: compacted cell indices, one segment per block
 * @param real_num_ref    output: number of selected cells per block
 */
__global__ void get_id_index_g(REAL *T, int *real_index_ref, int *real_num_ref){
    unsigned int s = blockDim.x * blockIdx.x + threadIdx.x;

    extern __shared__ int shared_int[];

    // The range test comes first so T is never read past size_d.
    const bool selected = (s < size_d) && (access_data(T, s) >= Tsd);

    shared_int[threadIdx.x] = selected ? 1 : 0;

    // Must be reached by every thread of the block, selected or not.
    __syncthreads();

    int tmp = 0;

    if (selected) {

        // Exclusive prefix count: selected threads with a smaller index.
        for (unsigned int i = 0; i < threadIdx.x; i++) {
            tmp += shared_int[i];
        }

        access_data(real_index_ref, blockDim.x * blockIdx.x + tmp) = s;

        // Inclusive count; its maximum over the block is the block total.
        tmp += shared_int[threadIdx.x];
    }

    atomicMax(real_num_ref + blockIdx.x, tmp);
}

/**
 * Gathers the selected cells of T_origin into a dense prefix of T and records
 * the total number of selected cells.
 *
 * Thread t of block b handles the t-th selected cell of that block (if any):
 * its destination slot is t plus the sum of real_num_ref over all earlier
 * blocks, and its source index comes from real_index_ref.
 *
 * Every thread, selected or not, folds (slot + 1) into *real_num_total_ref
 * with atomicMax, where slot is 0 for unselected threads.
 * NOTE(review): as a consequence the recorded total is at least 1 even when
 * no cell is selected — confirm this is intended.
 *
 * @param T                   output: densely packed values
 * @param T_origin            input: values indexed by original cell id
 * @param real_index_ref      compacted cell indices, one segment per block
 * @param real_num_ref        number of selected cells per block
 * @param real_num_total_ref  output: total count (only ever raised here)
 */
__global__ void reset_T_g(REAL *T, REAL *T_origin, int *real_index_ref, int *real_num_ref, int *real_num_total_ref){
    const unsigned int cell = blockDim.x * blockIdx.x + threadIdx.x;

    int slot = 0;

    const bool has_work = (cell < size_d) && (threadIdx.x < real_num_ref[blockIdx.x]);

    if (has_work) {
        // Selected cells of all preceding blocks come first in the output.
        for (unsigned int b = 0; b < blockIdx.x; ++b) {
            slot += real_num_ref[b];
        }
        slot += threadIdx.x;

        const int origin = access_data(real_index_ref, cell);
        access_data(T, slot) = access_data(T_origin, origin);
    }

    atomicMax(real_num_total_ref, slot + 1);
}

/**
 * Gathers the selected cells of T_origin into a dense prefix of T.
 *
 * Thread t of block b handles the t-th selected cell of that block (if any):
 * its destination slot is t plus the sum of real_num_ref over all earlier
 * blocks, and its source index comes from real_index_ref. get_id_index()
 * uses this kernel for the P / P_origin pair.
 *
 * @param T               output: densely packed values
 * @param T_origin        input: values indexed by original cell id
 * @param real_index_ref  compacted cell indices, one segment per block
 * @param real_num_ref    number of selected cells per block
 */
__global__ void reset_P_g(REAL *T, REAL *T_origin, int *real_index_ref, int *real_num_ref){
    const unsigned int cell = blockDim.x * blockIdx.x + threadIdx.x;

    if (cell >= size_d) {
        return;
    }
    if (threadIdx.x >= real_num_ref[blockIdx.x]) {
        return;
    }

    // Selected cells of all preceding blocks come first in the output.
    int slot = 0;
    for (unsigned int b = 0; b < blockIdx.x; ++b) {
        slot += real_num_ref[b];
    }
    slot += threadIdx.x;

    const int origin = access_data(real_index_ref, cell);
    access_data(T, slot) = access_data(T_origin, origin);
}

/**
 * Gathers all sp_num_d species values of every selected cell from Y_origin
 * into a dense prefix of Y.
 *
 * Thread t of block b handles the t-th selected cell of that block (if any):
 * its destination cell is t plus the sum of real_num_ref over all earlier
 * blocks, and its source cell comes from real_index_ref.
 *
 * @param Y               output: densely packed species data
 * @param Y_origin        input: species data indexed by original cell id
 * @param real_index_ref  compacted cell indices, one segment per block
 * @param real_num_ref    number of selected cells per block
 */
__global__ void reset_Y_g(REAL *Y, REAL *Y_origin, int *real_index_ref, int *real_num_ref){
    const unsigned int cell = blockDim.x * blockIdx.x + threadIdx.x;

    if (cell >= size_d) {
        return;
    }
    if (threadIdx.x >= real_num_ref[blockIdx.x]) {
        return;
    }

    const int origin = access_data(real_index_ref, cell);

    // Number of selected cells in all preceding blocks.
    int base = 0;
    for (unsigned int b = 0; b < blockIdx.x; ++b) {
        base += real_num_ref[b];
    }

    for (int sp = 0; sp < sp_num_d; ++sp) {
        access_sp_num_data(Y, sp, threadIdx.x + base) = access_sp_num_data(Y_origin, sp, origin);
    }
}

/**
 * Host driver: selects the cells with T_origin >= Tsd and packs their T, P
 * and Y data into the dense arrays T, P and Y.
 *
 * Four kernels are queued back to back on Stream_opencc[0], all with the
 * launch geometry from set_block_grid(size, block_set_real, ...):
 *   1. get_id_index_g  builds the per-block index list and counts,
 *   2. reset_T_g       packs T and records the total count,
 *   3. reset_P_g       packs P,
 *   4. reset_Y_g       packs Y.
 * No synchronisation is performed here.
 *
 * NOTE(review): real_num_ref and real_num_total_ref are only ever raised
 * (atomicMax) by these kernels — presumably the caller clears them first;
 * confirm. `sp_num` is not used by this function.
 */
__host__ void get_id_index(int size, int sp_num, REAL *T, REAL *T_origin, REAL *P, REAL *P_origin, REAL *Y, REAL *Y_origin, int *real_index_ref, int *real_num_ref, int *real_num_total_ref){

    dim3 grid_cfg;
    dim3 block_cfg;
    set_block_grid(size, block_set_real, grid_cfg, block_cfg);

    hipStream_t stream = Stream_opencc[0];

    // One int flag per thread for the in-block prefix count.
    const size_t flag_bytes = block_cfg.x*sizeof(int);

    get_id_index_g<<<grid_cfg, block_cfg, flag_bytes, stream>>>(T_origin, real_index_ref, real_num_ref);

    reset_T_g<<<grid_cfg, block_cfg, 0, stream>>>(T, T_origin, real_index_ref, real_num_ref, real_num_total_ref);

    reset_P_g<<<grid_cfg, block_cfg, 0, stream>>>(P, P_origin, real_index_ref, real_num_ref);

    reset_Y_g<<<grid_cfg, block_cfg, 0, stream>>>(Y, Y_origin, real_index_ref, real_num_ref);
}

/**
 * Inverse of reset_Y_g: scatters the densely packed species data in Y back
 * to the original cell positions in Y_origin.
 *
 * Thread t of block b handles the t-th selected cell of that block (if any):
 * its packed source cell is t plus the sum of real_num_ref over all earlier
 * blocks, and its destination cell comes from real_index_ref.
 *
 * @param Y               input: densely packed species data
 * @param Y_origin        output: species data indexed by original cell id
 * @param real_index_ref  compacted cell indices, one segment per block
 * @param real_num_ref    number of selected cells per block
 */
__global__ void reconstructY_g(REAL *Y, REAL *Y_origin, int *real_index_ref, int *real_num_ref){
    const unsigned int cell = blockDim.x * blockIdx.x + threadIdx.x;

    if (cell >= size_d) {
        return;
    }
    if (threadIdx.x >= real_num_ref[blockIdx.x]) {
        return;
    }

    const int origin = access_data(real_index_ref, cell);

    // Number of selected cells in all preceding blocks.
    int base = 0;
    for (unsigned int b = 0; b < blockIdx.x; ++b) {
        base += real_num_ref[b];
    }

    for (int sp = 0; sp < sp_num_d; ++sp) {
        access_sp_num_data(Y_origin, sp, origin) = access_sp_num_data(Y, sp, threadIdx.x + base);
    }
}

/**
 * Host driver: queues reconstructY_g on Stream_opencc[0] to scatter the
 * packed species data in Y back into Y_origin.
 *
 * The launch geometry comes from set_block_grid(size, block_set_real, ...),
 * the same call get_id_index() uses. No synchronisation is performed here;
 * `sp_num` is not used by this function.
 */
__host__ void reconstructY(int size, int sp_num, REAL *Y, REAL *Y_origin, int *real_index_ref, int *real_num_ref){

    dim3 grid_cfg;
    dim3 block_cfg;
    set_block_grid(size, block_set_real, grid_cfg, block_cfg);

    hipStream_t stream = Stream_opencc[0];
    reconstructY_g<<<grid_cfg, block_cfg, 0, stream>>>(Y, Y_origin, real_index_ref, real_num_ref);
}