#include "hip/hip_runtime.h"
/*******************************************************************************************
 * This file implements the GPU kernels and host launch wrappers for thermodynamic computations.
 ******************************************************************************************/

#include "species.h"
#include "thermoFluid.h"
#include "thermoFluid_kernel.h"

/*
 * Scatter per-cell temperature and pressure into the state array c:
 * row sp_num_d of c receives T and row sp_num_d+1 receives P.
 * Launch: 1-D, x index = cell; threads with index >= size_d do nothing.
 */
__global__ void update_c_from_T_P_g(REAL *c, REAL *T, REAL *P) {
    const unsigned int cell = blockIdx.x * blockDim.x + threadIdx.x;

    if (cell >= size_d) return;

    access_vars_data(c, sp_num_d, cell)     = access_data(T, cell);
    access_vars_data(c, sp_num_d + 1, cell) = access_data(P, cell);
}

/*
 * Host launcher for update_c_from_T_P_g: 1-D launch over all cells, enqueued
 * on Stream_opencc[0]. The launch is asynchronous and its status is not checked.
 */
__host__ void update_c_from_T_P_h(REAL *c, REAL *T, REAL *P) {
    dim3 blockdim, griddim;
    size_t n_cells = thermoFluid_d_ptr->thermo_ptr_d.size;

    set_block_grid(n_cells, block_set_common, griddim, blockdim);
    update_c_from_T_P_g<<<griddim, blockdim, 0, Stream_opencc[0]>>>(c, T, P);
}

/*
 * Gather per-cell temperature and pressure from the state array c:
 * T takes row sp_num_d of c and P takes row sp_num_d+1.
 * Launch: 1-D, x index = cell; threads with index >= size_d do nothing.
 */
__global__ void update_T_P_from_c_g(REAL *c, REAL *T, REAL *P) {
    const unsigned int cell = blockIdx.x * blockDim.x + threadIdx.x;

    if (cell >= size_d) return;

    access_data(T, cell) = access_vars_data(c, sp_num_d, cell);
    access_data(P, cell) = access_vars_data(c, sp_num_d + 1, cell);
}

/*
 * Host launcher for update_T_P_from_c_g: 1-D launch over all cells, enqueued
 * on Stream_opencc[0]. The launch is asynchronous and its status is not checked.
 */
__host__ void update_T_P_from_c_h(REAL *c, REAL *T, REAL *P) {
    dim3 blockdim, griddim;
    size_t n_cells = thermoFluid_d_ptr->thermo_ptr_d.size;

    set_block_grid(n_cells, block_set_common, griddim, blockdim);
    update_T_P_from_c_g<<<griddim, blockdim, 0, Stream_opencc[0]>>>(c, T, P);
}

/*
 * Mass fractions -> mole fractions, one thread per cell s < size:
 *   X_i   = (Y_i / W_i) / sum_j (Y_j / W_j)
 *   W_mix = 1 / sum_j (Y_j / W_j)
 * where W_i = sp_W[i].
 *
 * The unnormalised terms Y_i / W_i are staged directly in the output X and
 * rescaled in a second pass. No shared memory is needed, so the launch does not
 * depend on the species count. Each Y_i is read before X_i at the same (i, s)
 * is written, so in-place use (X == Y) behaves as before.
 *
 * Launch: 1-D, x index = cell; no dynamic shared memory.
 */
__global__ void formYupdateX_g(int size, REAL *sp_W, REAL *W_mix, REAL *X, REAL *Y)
{
    unsigned int s = blockDim.x * blockIdx.x + threadIdx.x;

    if (s < size) {
        REAL X_total = 0., tmp;

        for (int i = 0; i < sp_num_d; i++) {
            access_sp_num_data(X, i, s) = tmp = access_sp_num_data(Y, i, s) / *(sp_W + i);
            X_total += tmp;
        }

        access_data(W_mix, s) = tmp = 1./X_total;

        for (int i = 0; i < sp_num_d; i++) {
            access_sp_num_data(X, i, s) *= tmp;
        }
    }
}

/*
 * Host launcher for formYupdateX_g: 64 threads per block, enqueued on
 * Stream_opencc[0]. The kernel keeps its temporaries in X, so no dynamic shared
 * memory is requested. Per-block shared-memory limits therefore cannot make this
 * launch fail for large mechanisms. The launch status is not checked.
 */
__host__ void formYupdateX_h(int size, REAL *sp_W, REAL *W_mix, REAL *X, REAL *Y)
{
    dim3 griddim, blockdim;
    set_block_grid(size, 64, griddim, blockdim);

    formYupdateX_g<<<griddim, blockdim, 0, Stream_opencc[0]>>>(size, sp_W, W_mix, X, Y);
}

/*
 * Mole fractions -> mass fractions, one thread per cell s < size:
 *   Y_i   = X_i * W_i / sum_j (X_j * W_j)
 *   W_mix = sum_j (X_j * W_j)
 * where W_i = sp_W[i].
 *
 * Argument order: the 4th pointer is the OUTPUT (Y) and the 5th is the INPUT (X).
 * This is the reverse of formXupdateY_h's own (X, Y) parameter order.
 *
 * The products X_i * W_i are staged directly in Y and rescaled in a second
 * pass. No shared memory is needed, so the launch does not depend on the species
 * count. Each X_i is read before Y_i at the same (i, s) is written, so in-place
 * use (X == Y) behaves as before.
 *
 * Launch: 1-D, x index = cell; no dynamic shared memory.
 */
__global__ void formXupdateY_g(int size, REAL *sp_W, REAL *W_mix, REAL *Y, REAL *X)
{
    unsigned int s = blockDim.x * blockIdx.x + threadIdx.x;

    if (s < size) {
        REAL W_total = 0., tmp;

        for (int i = 0; i < sp_num_d; i++) {
            access_sp_num_data(Y, i, s) = tmp = access_sp_num_data(X, i, s) * *(sp_W + i);
            W_total += tmp;
        }

        access_data(W_mix, s) = W_total;
        tmp = 1./W_total;

        for (int i = 0; i < sp_num_d; i++) {
            access_sp_num_data(Y, i, s) *= tmp;
        }
    }
}

/*
 * Host launcher for formXupdateY_g: reads X and writes Y and W_mix.
 * Uses 64 threads per block, enqueued on Stream_opencc[0], with no dynamic shared
 * memory. Y is passed before X to match the kernel's (output, input) order.
 * The launch status is not checked.
 */
__host__ void formXupdateY_h(int size, REAL *sp_W, REAL *W_mix, REAL *X, REAL *Y)
{
    dim3 griddim, blockdim;
    set_block_grid(size, 64, griddim, blockdim);

    formXupdateY_g<<<griddim, blockdim, 0, Stream_opencc[0]>>>(size, sp_W, W_mix, Y, X);
}

/*
 * Ideal-gas density per cell: rho = P * W_mix / (R * T).
 * R is a constant defined outside this file.
 * Launch: 1-D, x index = cell; threads with index >= size_d do nothing.
 */
__global__ void get_rho_g(REAL *P, REAL *T, REAL *W_mix, REAL *rho) 
{
    const unsigned int cell = blockIdx.x * blockDim.x + threadIdx.x;

    if (cell >= size_d) return;

    access_data(rho, cell) = access_data(P, cell) * access_data(W_mix, cell)
                           / (R * access_data(T, cell));
}

/*
 * Host launcher for get_rho_g: 1-D launch over all cells, enqueued on
 * Stream_opencc[0]. The launch status is not checked.
 */
__host__ void get_rho_h(REAL *P, REAL *T, REAL *W_mix, REAL *rho)
{
    dim3 blockdim, griddim;
    size_t n_cells = thermoFluid_d_ptr->thermo_ptr_d.size;

    set_block_grid(n_cells, block_set_thermoFluid, griddim, blockdim);
    get_rho_g<<<griddim, blockdim, 0, Stream_opencc[0]>>>(P, T, W_mix, rho);
}

/*
 * Per-species, per-cell c[w][s] = rho[s] * Y[w][s] / sp_W[w]. This is the molar
 * concentration, assuming rho is density and sp_W is molecular weight.
 * Launch: 2-D, x index = cell (< size_d), y index = species (< sp_num_d).
 */
__global__ void get_c_g(REAL *rho, REAL *Y, REAL *sp_W, REAL *c) 
{
    const unsigned int cell    = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int species = blockIdx.y * blockDim.y + threadIdx.y;

    if (species >= sp_num_d || cell >= size_d) return;

    access_vars_data(c, species, cell) =
        access_data(rho, cell) * access_sp_num_data(Y, species, cell) / sp_W[species];
}

/*
 * Host launcher for get_c_g: 2-D launch covering (cells x species), enqueued
 * on Stream_opencc[0]. The launch status is not checked.
 */
__host__ void get_c_h(REAL *rho, REAL *Y, REAL *sp_W, REAL *c) {
    size_t n_cells   = thermoFluid_d_ptr->thermo_ptr_d.size;
    size_t n_species = species_d_ptr->species_const_d.sp_num;
    dim3 blockdim, griddim;

    set_block_grid2d(n_cells, n_species, block_set_J02d, griddim, blockdim);
    get_c_g<<<griddim, blockdim, 0, Stream_opencc[0]>>>(rho, Y, sp_W, c);
}

/*
 * Per-cell reciprocal of rho: psi = 1 / rho.
 * Launch: 1-D, x index = cell; threads with index >= size_d do nothing.
 */
__global__ void get_psi_g(REAL *rho, REAL *psi) 
{
    const unsigned int cell = blockIdx.x * blockDim.x + threadIdx.x;

    if (cell >= size_d) return;

    access_data(psi, cell) = 1 / access_data(rho, cell);
}

/*
 * Host launcher for get_psi_g: 1-D launch over all cells, enqueued on
 * Stream_opencc[0]. The launch status is not checked.
 */
__host__ void get_psi_h(REAL *rho, REAL *psi)
{
    dim3 blockdim, griddim;
    size_t n_cells = thermoFluid_d_ptr->thermo_ptr_d.size;

    set_block_grid(n_cells, block_set_thermoFluid, griddim, blockdim);
    get_psi_g<<<griddim, blockdim, 0, Stream_opencc[0]>>>(rho, psi);
}

/*
 * Molar enthalpy from polynomial coefficients a = nasa[0..5] at temperature *T:
 *   h = (a0*T + a1*T^2/2 + a2*T^3/3 + a3*T^4/4 + a4*T^5/5 + a5) * 8313.8462
 * NOTE(review): 8313.8462 looks like the universal gas constant in J/(kmol K),
 * but the CODATA value is 8314.46; confirm this is intentional.
 */
__device__ REAL compute_ha_d(REAL *nasa, REAL *T) {
    const REAL t  = *T;
    const REAL t2 = t * t;

    return (nasa[0]*t + nasa[1]*0.5*t2 + nasa[2]*t2*t/3. +
            nasa[3]*t2*t2*0.25 + nasa[4]*t2*t2*t*0.2 +
            nasa[5])*8313.8462;
}

/*
 * Molar heat capacity from polynomial coefficients a = nasa[0..4] at *T:
 *   cp = (a0 + a1*T + a2*T^2 + a3*T^3 + a4*T^4) * 8313.8462
 * It uses the same hard-coded scale factor as compute_ha_d.
 */
__device__ REAL compute_cp_d(REAL *nasa, REAL *T) {
    const REAL t  = *T;
    const REAL t2 = t * t;

    return (nasa[0] + nasa[1]*t + nasa[2]*t2 +
            nasa[3]*t2*t + nasa[4]*t2*t2)*8313.8462;
}

/*
 * Per-species mass enthalpy at the temperature Tsd:
 *   hc_mass[w] = h_w(Tsd) / sp_W[w]
 * It always uses coefficients 0..5 of species w's 14-entry block in sp_nasa.
 * Tsd is defined outside this file (presumably a standard/reference temperature).
 *
 * Launch: 2-D, y index = species. The result does not depend on the x (cell)
 * index, so only the s == 0 thread of each species writes it. Other x-threads
 * would otherwise store the identical value to the same address.
 */
__global__ void compute_hc_mass_g(REAL *sp_nasa, REAL *sp_W, REAL *hc_mass) {
    unsigned int s = blockDim.x * blockIdx.x + threadIdx.x;
    unsigned int w = blockDim.y * blockIdx.y + threadIdx.y;

    if (w < sp_num_d && s == 0 && s < size_d) {
        REAL T_tmp = Tsd;

        // Low-temperature coefficient set: entries 0..5 of this species' block.
        access_sp_num_data(hc_mass, w, 0) = compute_ha_d(sp_nasa + w*14, &T_tmp) / *(sp_W + w);
    }
}

/*
 * Host launcher for compute_hc_mass_g: 2-D launch with a single cell column
 * (the result depends only on the species), enqueued on Stream_opencc[0].
 * The launch status is not checked.
 */
__host__ void compute_hc_mass_h(REAL *sp_nasa, REAL *sp_W, REAL *hc_mass) {
    size_t n_species = species_d_ptr->species_const_d.sp_num;
    size_t n_cols    = 1;
    dim3 blockdim, griddim;

    set_block_grid2d(n_cols, n_species, block_set_J02d, griddim, blockdim);
    compute_hc_mass_g<<<griddim, blockdim, 0, Stream_opencc[0]>>>(sp_nasa, sp_W, hc_mass);
}

/*
 * Per-species, per-cell molar enthalpy and heat capacity at T[s].
 *
 * Each species w owns 14 coefficients starting at sp_nasa[w*14]. Entries 7..12
 * are used when T[s] >= T_range[w*3 + 1]; otherwise entries 0..5 are used.
 * Cells whose dt_sum_d[s] is not <= t_end_h are skipped, including NaN values
 * (presumably cells whose integration already finished).
 *
 * Launch: 2-D, x index = cell (< size_d), y index = species (< sp_num_d).
 */
__global__ void compute_ha_cp_g(REAL *T, REAL *T_range, REAL *sp_nasa, REAL *sp_ha_mole, REAL *sp_cp_mole, REAL *dt_sum_d, REAL t_end_h) {
    const unsigned int cell    = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int species = blockIdx.y * blockDim.y + threadIdx.y;

    if (species >= sp_num_d || cell >= size_d) return;
    if (!(access_data(dt_sum_d, cell) <= t_end_h)) return;

    REAL T_cell = access_data(T, cell);

    // Select the coefficient set for this temperature within the species' block.
    const unsigned int set_offset = (T_cell >= T_range[species*3 + 1]) ? 7u : 0u;
    REAL *coef = sp_nasa + species*14 + set_offset;

    access_sp_num_data(sp_ha_mole, species, cell) = compute_ha_d(coef, &T_cell);
    access_sp_num_data(sp_cp_mole, species, cell) = compute_cp_d(coef, &T_cell);
}

/*
 * Host launcher for compute_ha_cp_g: 2-D launch covering (cells x species),
 * enqueued on Stream_opencc[0]. dt_sum_d and t_end_h are not declared in this
 * file; they come from an included header. The launch status is not checked.
 */
__host__ void compute_ha_cp_h(REAL *T, REAL *T_range, REAL *sp_nasa, REAL *sp_ha_mole, REAL *sp_cp_mole) {
    size_t n_cells   = thermoFluid_d_ptr->thermo_ptr_d.size;
    size_t n_species = species_d_ptr->species_const_d.sp_num;
    dim3 blockdim, griddim;

    set_block_grid2d(n_cells, n_species, block_set_J02d, griddim, blockdim);
    compute_ha_cp_g<<<griddim, blockdim, 0, Stream_opencc[0]>>>(
        T, T_range, sp_nasa, sp_ha_mole, sp_cp_mole, dt_sum_d, t_end_h);
}

/*
 * Mixture mass enthalpy of cell *s from precomputed molar enthalpies:
 *   sum_i Y[i][s] * sp_ha_mole[i][s] / sp_W[i]
 */
__device__ REAL compute_ha_mass(REAL *sp_ha_mole, REAL *sp_W, REAL *Y, unsigned int *s) {
    const unsigned int cell = *s;
    REAL h_mass = 0;

    for (int species = 0; species < sp_num_d; ++species)
        h_mass += access_sp_num_data(Y, species, cell) * access_sp_num_data(sp_ha_mole, species, cell)
                / sp_W[species];

    return h_mass;
}

/*
 * Per-cell mixture mass enthalpy (see compute_ha_mass), written to ha_mass[s].
 * Launch: 1-D, x index = cell; threads with index >= size_d do nothing.
 */
__global__ void compute_ha_mass_g(REAL *sp_ha_mole, REAL *sp_W, REAL *Y, REAL *ha_mass) {
    unsigned int cell = blockIdx.x * blockDim.x + threadIdx.x;

    if (cell >= size_d) return;

    access_data(ha_mass, cell) = compute_ha_mass(sp_ha_mole, sp_W, Y, &cell);
}

/*
 * Host launcher for compute_ha_mass_g: 1-D launch over all cells, enqueued on
 * Stream_opencc[0]. The launch status is not checked.
 */
__host__ void compute_ha_mass_h(REAL *sp_ha_mole, REAL *sp_W, REAL *Y, REAL *ha_mass) {
    size_t size = thermoFluid_d_ptr->thermo_ptr_d.size;

    dim3 griddim, blockdim;

    set_block_grid(size, block_set_thermoFluid, griddim, blockdim);

    compute_ha_mass_g<<<griddim, blockdim, 0, Stream_opencc[0]>>>(sp_ha_mole, sp_W, Y, ha_mass);
}

/*
 * Mixture mass enthalpy at temperature *T_tmp, evaluated from the polynomial
 * coefficients rather than from precomputed molar enthalpies:
 *   sum_w Y[w][s] * h_w(*T_tmp) / sp_W[w]
 * For species w, coefficients 7..12 of its 14-entry block are used when
 * *T_tmp >= T_range[w*3 + 1]; otherwise entries 0..5 are used.
 *
 * The cell index s is derived from the calling thread's x coordinate, so the
 * caller must use a 1-D launch with x index = cell.
 */
__device__ REAL compute_ha_mass_ofc_g(REAL *T_tmp, REAL *T_range, REAL *sp_nasa, REAL *sp_W, REAL *Y) {
    const unsigned int cell = blockIdx.x * blockDim.x + threadIdx.x;
    REAL h_mass = 0.;

    for (int species = 0; species < sp_num_d; species++) {
        const int set_offset = (*T_tmp >= T_range[species*3 + 1]) ? 7 : 0;
        REAL *coef = sp_nasa + species*14 + set_offset;

        const REAL h_mole = compute_ha_d(coef, T_tmp);
        h_mass += access_sp_num_data(Y, species, cell) * h_mole / sp_W[species];
    }

    return h_mass;
}

/*
 * Mixture mass heat capacity at temperature *T_tmp, evaluated from the
 * polynomial coefficients:
 *   sum_w Y[w][s] * cp_w(*T_tmp) / sp_W[w]
 * Coefficient-set selection matches compute_ha_mass_ofc_g: entries 7..12 of
 * species w's 14-entry block when *T_tmp >= T_range[w*3 + 1], else 0..5.
 *
 * The cell index s is derived from the calling thread's x coordinate, so the
 * caller must use a 1-D launch with x index = cell.
 */
__device__ REAL compute_cp_mass_ofc_g(REAL *T_tmp, REAL *T_range, REAL *sp_nasa, REAL *sp_W, REAL *Y) {
    const unsigned int cell = blockIdx.x * blockDim.x + threadIdx.x;
    REAL cp_mass = 0.;

    for (int species = 0; species < sp_num_d; species++) {
        const int set_offset = (*T_tmp >= T_range[species*3 + 1]) ? 7 : 0;
        REAL *coef = sp_nasa + species*14 + set_offset;

        const REAL cp_mole = compute_cp_d(coef, T_tmp);
        cp_mass += access_sp_num_data(Y, species, cell) * cp_mole / sp_W[species];
    }

    return cp_mass;
}

/*
 * Per-cell inversion of h_mass(T) = h_target[s] by Newton's method:
 *   T <- T - (h_mass(T) - h_target) / cp_mass(T)
 * h_mass and cp_mass are the Y-weighted mixture sums from
 * compute_ha_mass_ofc_g / compute_cp_mass_ofc_g.
 *
 * The initial guess is row sp_num_d of c (the temperature slot); c itself is
 * only read. A fixed number of iterations is run with no convergence test, so
 * every thread executes the same trip count. No guard is applied if cp_mass
 * evaluates to zero. The result is written to T[s].
 *
 * Launch: 1-D, x index = cell. This layout is required because the two helpers
 * derive the cell index from the calling thread's x coordinate.
 */
__global__ void enthalpy_to_temperature_g(REAL *c, REAL *T_range, REAL *sp_nasa, REAL *Y, REAL *sp_W, REAL *h_target, REAL *T) {
    const int newton_iters = 20;

    unsigned int s = blockDim.x * blockIdx.x + threadIdx.x;

    if (s < size_d) {
        REAL ha_ta  = access_data(h_target, s);
        REAL T_iter = access_vars_data(c, sp_num_d, s);

        for (int i = 0; i < newton_iters; i++) {
            // Residual and its derivative are both evaluated at the current iterate.
            REAL ha_mass = compute_ha_mass_ofc_g(&T_iter, T_range, sp_nasa, sp_W, Y);
            REAL cp_mass = compute_cp_mass_ofc_g(&T_iter, T_range, sp_nasa, sp_W, Y);

            T_iter = T_iter - (ha_mass - ha_ta)/cp_mass;
        }

        access_data(T, s) = T_iter;
    }
}

/*
 * Host launcher for enthalpy_to_temperature_g: 1-D launch over all cells,
 * enqueued on Stream_opencc[0]. The launch status is not checked.
 */
__host__ void enthalpy_to_temperature_h(REAL *c, REAL *T_range, REAL *sp_nasa, REAL *Y, REAL *sp_W, REAL *h_target, REAL *T) {
    size_t size = thermoFluid_d_ptr->thermo_ptr_d.size;

    dim3 griddim, blockdim;

    set_block_grid(size, block_set_thermoFluid, griddim, blockdim);

    enthalpy_to_temperature_g<<<griddim, blockdim, 0, Stream_opencc[0]>>>(c, T_range, sp_nasa, Y, sp_W, h_target, T);
}

/*
 * Per cell, normalise the species rows of c into X:
 *   X[w][s] = c[w][s] / sum_j c[j][s]
 * No guard is applied when the sum is zero.
 * Launch: 1-D, x index = cell; threads with index >= size_d do nothing.
 */
__global__ void h_constraint_g(REAL *X, REAL *c) {
    const unsigned int cell = blockIdx.x * blockDim.x + threadIdx.x;

    if (cell >= size_d) return;

    REAL c_sum = 0.;
    for (int species = 0; species < sp_num_d; ++species)
        c_sum += access_vars_data(c, species, cell);

    for (int species = 0; species < sp_num_d; ++species)
        access_sp_num_data(X, species, cell) = access_vars_data(c, species, cell) / c_sum;
}

/*
 * Host launcher for h_constraint_g: 1-D launch over all cells, enqueued on
 * Stream_opencc[0]. The launch status is not checked.
 */
__host__ void h_constraint_h(REAL *X, REAL *c) {
    size_t size = thermoFluid_d_ptr->thermo_ptr_d.size;

    dim3 griddim, blockdim;

    set_block_grid(size, block_set_thermoFluid, griddim, blockdim);

    h_constraint_g<<<griddim, blockdim, 0, Stream_opencc[0]>>>(X, c);
}

/*
 * Clamp every species entry of c to be non-negative. fmax also maps a NaN
 * entry to 0.
 * Launch: 2-D, x index = cell (< size_d), y index = species (< sp_num_d).
 */
__global__ void c_constraint_g(REAL *c) {
    const unsigned int cell    = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int species = blockIdx.y * blockDim.y + threadIdx.y;

    if (cell < size_d && species < sp_num_d) {
        access_vars_data(c, species, cell) = fmax(0.0, access_vars_data(c, species, cell));
    }
}

/*
 * Host launcher for c_constraint_g: 2-D launch covering (cells x species),
 * enqueued on Stream_opencc[0]. The launch status is not checked.
 */
__host__ void c_constraint_h(REAL *c) {
    size_t n_cells   = thermoFluid_d_ptr->thermo_ptr_d.size;
    size_t n_species = species_d_ptr->species_const_d.sp_num;
    dim3 blockdim, griddim;

    set_block_grid2d(n_cells, n_species, block_set_J02d, griddim, blockdim);
    c_constraint_g<<<griddim, blockdim, 0, Stream_opencc[0]>>>(c);
}

