#include "hip/hip_runtime.h"
/*******************************************************************************************
 * This file contains the implementation of the GPU functions used to compute reaction rates.
 ******************************************************************************************/

#include "species.h"
#include "thermoFluid.h"
#include "reactions.h"
#include "reactions_kernel.h"

// Copies temperature and pressure for each cell out of the state vector c
// into the separate T and P arrays.
// Launch: 1-D over cells; one thread per cell.
// In c, slot sp_num_d holds T and slot sp_num_d+1 holds P.
// Cells with access_data(dt_sum_d, s) > t_end_h are left untouched.
__global__ void compute_T_P_g(REAL *c, REAL *T, REAL *P, REAL *dt_sum_d, REAL t_end_h) {
    const unsigned int cell = blockIdx.x * blockDim.x + threadIdx.x;

    // Written as !(...) so a NaN in dt_sum_d still skips the cell, exactly as before.
    const bool active = (cell < size_d) && (access_data(dt_sum_d, cell) <= t_end_h);
    if (!active) {
        return;
    }

    access_data(T, cell) = access_vars_data(c, sp_num_d, cell);
    access_data(P, cell) = access_vars_data(c, sp_num_d + 1, cell);
}

// Host launcher for compute_T_P_g.
// Uses a 1-D launch over all cells, enqueued asynchronously on Stream_opencc[0].
__host__ void compute_T_P_h(REAL *c, REAL *T, REAL *P) {
    dim3 blockdim, griddim;
    size_t n_cells = thermoFluid_d_ptr->thermo_ptr_d.size;

    set_block_grid(n_cells, block_set_common, griddim, blockdim);

    compute_T_P_g<<<griddim, blockdim, 0, Stream_opencc[0]>>>(c, T, P, dt_sum_d, t_end_h);
}

// Evaluates, at temperature *T, the expression
//   a6 - a0 + (a0 - 1) ln T + a1 T/2 + a2 T^2/6 + a3 T^3/12 + a4 T^4/20 - a5/T
// where a0..a6 = nasa[0..6].
// Assuming nasa holds the 7 NASA-polynomial coefficients of one temperature
// range, this equals -G/(RT) - ln T for that species.
// The polynomial part is evaluated in Horner form.
__device__ REAL compute_B_d(REAL *nasa, REAL *T) {
    const REAL t = *T;
    const REAL sixth = 1/6.;   // kept as REAL so float builds round exactly as before

    return nasa[6] - nasa[0]
         + (nasa[0] - 1.) * log(t)
         + t * (nasa[1] * 0.5 + t * (nasa[2] * sixth + t * (nasa[3] * 0.5 * sixth + t * nasa[4] * 0.05)))
         - nasa[5] / t;
}

// Fills B[species w][cell s] = compute_B_d(coefficients of w, T[s]).
// Launch: 2-D; x indexes cells (< size_d), y indexes species (< sp_num_d).
// Data layout, per species w:
//   - sp_nasa holds 14 coefficients. Entries [7, 14) are used when
//     T >= T_range[3*w + 1]; otherwise entries [0, 7) are used.
//   - T_range holds 3 values; only the middle one (the switch-over
//     temperature) is read here.
// Cells with access_data(dt_sum_d, s) > t_end_h are skipped.
__global__ void compute_B_g(REAL *T, REAL *T_range, REAL *sp_nasa, REAL *B, REAL *dt_sum_d, REAL t_end_h) {
    const unsigned int s = blockIdx.x * blockDim.x + threadIdx.x;   // cell index
    const unsigned int w = blockIdx.y * blockDim.y + threadIdx.y;   // species index

    if (!(w < sp_num_d && s < size_d && access_data(dt_sum_d, s) <= t_end_h)) {
        return;
    }

    REAL T_cell = access_data(T, s);
    const REAL T_switch = T_range[1 + 3 * w];

    // Select the coefficient set for the temperature range T_cell falls in.
    const REAL *coeffs = sp_nasa + 14 * w + (T_cell >= T_switch ? 7 : 0);

    REAL nasa_local[7];
    for (int j = 0; j < 7; ++j) {
        nasa_local[j] = coeffs[j];
    }

    access_sp_num_data(B, w, s) = compute_B_d(nasa_local, &T_cell);
}

// Host launcher for compute_B_g.
// Uses a 2-D launch sized from the cell count and species count, enqueued on
// Stream_opencc[0]. The kernel indexes cells along x and species along y.
__host__ void compute_B_h(REAL *T, REAL *T_range, REAL *sp_nasa, REAL *B) {
    dim3 griddim, blockdim;

    size_t n_cells   = thermoFluid_d_ptr->thermo_ptr_d.size;
    size_t n_species = species_d_ptr->species_const_d.sp_num;

    set_block_grid2d(n_cells, n_species, block_set_J02d, griddim, blockdim);

    compute_B_g<<<griddim, blockdim, 0, Stream_opencc[0]>>>(T, T_range, sp_nasa, B, dt_sum_d, t_end_h);
}

// Single-expression reverse rate for reaction w in cell s:
//   a * T^b * exp(-ea/T - sum_i nu_i B_i) / p_atm_d_R^(sum_i nu_i)
// where nu_i = v_net[w + react_num_d*i] is the net stoichiometric coefficient
// of species i.
// This is kf divided by pow(p_atm_d_R, sum nu) * exp(sum nu B), the same
// quantity compute_kf_kr_g forms for kr.
// NOTE(review): not called in the visible code; may be used elsewhere.
__device__ REAL compute_kf_kr_d(REAL *a, REAL *b, REAL *ea, REAL *T, REAL *v_net, REAL *B, int w, int s) {
    REAL nu_total = 0;   // sum of net stoichiometric coefficients
    REAL nu_dot_B = 0;   // sum of nu_k * B_k

    for (int k = 0; k < sp_num_d; k++) {
        const REAL nu = v_net[w + react_num_d * k];
        nu_total += nu;
        nu_dot_B += nu * access_sp_num_data(B, k, s);
    }

    const REAL kr_val = *a * pow(*T, *b) * exp(-*ea / *T - nu_dot_B) / pow(p_atm_d_R, nu_total);
    return kr_val;
}

// Modified Arrhenius expression k = a * T^b * exp(-ea / T).
// ea is divided directly by T, so it is presumably the activation energy
// already divided by R (i.e. in K) — confirm against the mechanism loader.
//
// For a > 0 the value is formed as exp(log a + b log T - ea/T). This keeps the
// original bit-exact result and avoids overflowing an intermediate T^b or
// exp() term before the multiplication by a.
// The sign of a is handled separately because log(a) is NaN for a < 0.
// Negative pre-exponential factors are legal for duplicate reactions in
// Chemkin/Cantera mechanisms. a == 0 returns 0, matching the old exp(-inf).
__device__ REAL compute_kf_d(REAL *a, REAL *b, REAL *ea, REAL *T) {
    const REAL pre_exp = *a;

    if (pre_exp > 0) {
        return exp(log(pre_exp) + *b * log(*T) - *ea / *T);
    }
    if (pre_exp == 0) {
        return 0;
    }
    // Negative (or NaN) prefactor: apply the same form to |a| and restore the
    // sign. NaN still propagates as NaN.
    return -exp(log(-pre_exp) + *b * log(*T) - *ea / *T);
}

// Computes per-reaction, per-cell rate constants.
// Launch: 2-D; x indexes cells (< size_d), y indexes reactions (< react_num_d).
//
// Each reaction w has 6 entries in abe: [A, b, Ea, A_low, b_low, Ea_low].
// Outputs:
//   kf[w][s]     = Arrhenius(A, b, Ea) at T[s].
//   kf_low[w][s] = Arrhenius(A_low/A, b_low - b, Ea_low - Ea), i.e. the
//                  low/high rate ratio.
//   kr[w][s]     = kf / Kc, where Kc = p_atm_d_R^(sum nu) * exp(sum nu*B),
//                  if is_rev[w] == 1. It is 0 when the reaction is
//                  irreversible or Kc == 0.
// Cells with access_data(dt_sum_d, s) > t_end_h are skipped.
__global__ void compute_kf_kr_g(REAL *abe, REAL *T, REAL *v_net, REAL *is_rev, REAL *B, REAL *kf, REAL *kr, REAL *kf_low, REAL *dt_sum_d, REAL t_end_h) {
    const unsigned int s = blockIdx.x * blockDim.x + threadIdx.x;   // cell index
    const unsigned int w = blockIdx.y * blockDim.y + threadIdx.y;   // reaction index

    if (!(w < react_num_d && s < size_d && access_data(dt_sum_d, s) <= t_end_h)) {
        return;
    }

    const REAL *rate = abe + 6 * w;

    REAL A_hi  = rate[0];
    REAL b_hi  = rate[1];
    REAL Ea_hi = rate[2];

    // Ratio parameters, so that Arrhenius(ratio) = k_low / k_high.
    REAL A_ratio = rate[3] / A_hi;
    REAL b_diff  = rate[4] - b_hi;
    REAL Ea_diff = rate[5] - Ea_hi;

    REAL T_cell = access_data(T, s);

    const REAL kf_val = compute_kf_d(&A_hi, &b_hi, &Ea_hi, &T_cell);
    access_react_num_data(kf, w, s) = kf_val;
    access_react_num_data(kf_low, w, s) = compute_kf_d(&A_ratio, &b_diff, &Ea_diff, &T_cell);

    REAL kr_val = 0.;
    if (is_rev[w] == 1) {
        REAL nu_total = 0;
        REAL nu_dot_B = 0;
        for (int k = 0; k < sp_num_d; k++) {
            const REAL nu = v_net[w + react_num_d * k];
            nu_total += nu;
            nu_dot_B += nu * access_sp_num_data(B, k, s);
        }

        const REAL Kc = pow(p_atm_d_R, nu_total) * exp(nu_dot_B);
        kr_val = (Kc == 0) ? 0. : kf_val / Kc;
    }
    access_react_num_data(kr, w, s) = kr_val;
}

// Host launcher for compute_kf_kr_g.
// Uses a 2-D launch over cells x reactions, enqueued on Stream_opencc[0].
// If the thermo scope (thermoFluid_d_ptr) has not been created, it prints an
// error and returns without launching. Previously it printed the message and
// then dereferenced the null pointer on the next line.
__host__ void compute_kf_kr_h(REAL *abe, REAL *T, REAL *v_net, REAL *is_rev, REAL *B, REAL *kf, REAL *kr, REAL *kf_low) {
    if (thermoFluid_d_ptr == nullptr) {
        MPI_PRINTF("\033[31mBEFORE SETTING UP THE REACTION CLASS, PLEASE CREATE THE THERMO SCOPE.\033[0m\n");
        // Everything below dereferences thermoFluid_d_ptr; bail out instead of crashing.
        return;
    }

    size_t size = thermoFluid_d_ptr->thermo_ptr_d.size;
    size_t react_num = reactions_d_ptr->reactions_ptr_d.react_num;

    dim3 griddim, blockdim;

    set_block_grid2d(size, react_num, block_set_J02d, griddim, blockdim);

    compute_kf_kr_g<<<griddim, blockdim, 0, Stream_opencc[0]>>>(abe, T, v_net, is_rev, B, kf, kr, kf_low, dt_sum_d, t_end_h);
}

// Returns k[w][s] * prod_k c_k^nu_k, where nu_k = vf[w + k*react_num_d].
// Only species with nu_k > 0 contribute to the product.
// It is used with (vf, kf) for the forward rate and with (vr, kr) for the
// reverse rate.
__device__ REAL compute_Rf(int w, int s, REAL *vf, REAL *kf, REAL *c) {
    REAL conc_product = 1.;

    for (int k = 0; k < sp_num_d; k++) {
        const REAL nu = vf[w + k * react_num_d];
        if (nu > 0) {
            conc_product *= pow(access_vars_data(c, k, s), nu);
        }
    }

    const REAL rate = access_react_num_data(kf, w, s) * conc_product;
    return rate;
}

// Troe falloff broadening factor.
//
// Inputs:
//   *T                 temperature
//   *pri               reduced pressure Pr
//   *a_troe, *T1_troe, *T2_troe, *T3_troe   Troe parameters, used as below
//
// Outputs:
//   *F_cent = (1 - a) exp(-T/T3) + a exp(-T/T1) + exp(-T2/T)
//   *Fi     = F_cent^(1 / (1 + f1^2)), with
//             f1 = (log10 Pr + c) / (n - 0.14 (log10 Pr + c)),
//             c  = -0.4 - 0.67 log10 F_cent,
//             n  =  0.75 - 1.27 log10 F_cent.
//   The denominator below is that expression expanded:
//   0.806 - 1.1762 log10 F_cent - 0.14 log10 Pr.
//
// Both logarithms are clamped at 1e-40 to avoid log10(0).
// Changes from the previous version:
//   - Uses type-generic log10 instead of single-precision log10f, so double
//     REAL values are not truncated to float.
//   - A zero T2 omits the exp(-T2/T) term (the 3-parameter Troe form)
//     instead of adding exp(0) = 1.
//     NOTE(review): confirm the mechanism loader stores 0 for an omitted T**.
__device__ void compute_Fcent_Fi(REAL *T, REAL *pri, REAL *a_troe, REAL *T1_troe, REAL *T2_troe, REAL *T3_troe, REAL *F_cent, REAL *Fi) {
    const REAL temp = *T;

    const REAL fcent = (1.0 - *a_troe) * exp(-temp / *T3_troe)
                     + *a_troe * exp(-temp / *T1_troe)
                     + ((*T2_troe != 0) ? exp(-*T2_troe / temp) : 0.);

    *F_cent = fcent;

    const REAL log_fcent = log10(fmax(fcent, 1e-40));
    const REAL log_pr    = log10(fmax(*pri, 1e-40));

    const REAL A_troe = log_pr - 0.67 * log_fcent - 0.4;
    const REAL B_troe = 0.806 - 1.1762 * log_fcent - 0.14 * log_pr;

    const REAL f1 = A_troe / B_troe;

    *Fi = pow(fcent, 1. / ((f1 * f1) + 1.));
}

// Computes the per-reaction multiplier react_c and the net rate of progress q.
// Launch: 2-D; x indexes cells (< size_d), y indexes reactions (< react_num_d).
//
// Multiplier react_c[w][s], by react_type[w]:
//   0 -> 1
//   1 -> [M] = sum_k tb_coeffs[w*sp_num_d + k] * c_k
//   2 -> falloff, Pr/(1+Pr), where Pr = [M] * kf_low[w][s]. kf_low is filled
//        by compute_kf_kr_g with the low/high rate ratio. If fall_type[w] == 2,
//        the result is further multiplied by the Troe factor from
//        compute_Fcent_Fi.
//
// Rate of progress:
//   q[w][s] = react_c * (Rf - Rr), with Rr computed only when is_rev[w] == 1.
//
// Fix: react_c_tmp was previously read uninitialized for unrecognized
// react_type values, and for react_type == 2 with fall_type other than 1/2.
// It now defaults to 1 (unknown react_type) or to the Lindemann value Pr/(1+Pr)
// (unknown fall_type), and react_c is always written.
//
// Cells with access_data(dt_sum_d, s) > t_end_h are skipped.
__global__ void compute_react_c_g(REAL *tb_coeffs, REAL *c, REAL *T, REAL *kf, REAL *kr, REAL *kf_low, REAL *react_c, 
    REAL *react_type, REAL *fall_type, REAL *fall_coeffs, REAL *vr, REAL *vf, REAL *is_rev, REAL *q, REAL *dt_sum_d, REAL t_end_h) {
    unsigned int s = blockDim.x * blockIdx.x + threadIdx.x;
    unsigned int w = blockDim.y * blockIdx.y + threadIdx.y; //react_num

    if (w < react_num_d && s < size_d && access_data(dt_sum_d, s) <= t_end_h) {
        // Third-body-weighted concentration [M].
        REAL sum_c = 0;
        for (int k = 0; k < sp_num_d; k++) {
            sum_c += *(tb_coeffs + k + sp_num_d*w) * access_vars_data(c, k, s);
        }

        // Defined default: never read this value uninitialized.
        REAL react_c_tmp = 1.;

        if (react_type[w] == 1) {
            react_c_tmp = sum_c;
        } else if (react_type[w] == 2) {
            // Reduced pressure Pr and the Lindemann factor Pr/(1+Pr).
            REAL pri0 = sum_c * access_react_num_data(kf_low, w, s);
            REAL pri1 = pri0 / (1 + pri0);

            react_c_tmp = pri1;

            if (fall_type[w] == 2) {
                REAL T_tmp = access_data(T, s);
                REAL fc_tmp[4];
                REAL F_cent, Fi;

                fc_tmp[0] = fall_coeffs[5*w];
                fc_tmp[1] = fall_coeffs[5*w + 1];
                fc_tmp[2] = fall_coeffs[5*w + 2];
                fc_tmp[3] = fall_coeffs[5*w + 3];

                compute_Fcent_Fi(&T_tmp, &pri0,
                                 &fc_tmp[0],
                                 &fc_tmp[1],
                                 &fc_tmp[2],
                                 &fc_tmp[3],
                                 &F_cent, &Fi);

                react_c_tmp = pri1 * Fi;
            }
        }

        access_react_num_data(react_c, w, s) = react_c_tmp;

        REAL Rf = compute_Rf(w, s, vf, kf, c);
        REAL Rr = 0;
        if (is_rev[w] == 1) {
            Rr = compute_Rf(w, s, vr, kr, c);
        }

        access_react_num_data(q, w, s) = react_c_tmp * (Rf - Rr);
    }
}

// Host launcher for compute_react_c_g.
// Uses a 2-D launch sized from the cell count and reaction count, enqueued on
// Stream_opencc[0]. The kernel indexes cells along x and reactions along y.
__host__ void compute_react_c_h(REAL *tb_coeffs, REAL *c, REAL *T, REAL *kf, REAL *kr, REAL *kf_low, REAL *react_c, 
    REAL *react_type, REAL *fall_type, REAL *fall_coeffs, REAL *vr, REAL *vf, REAL *is_rev, REAL *q) {
    dim3 griddim, blockdim;

    size_t n_cells     = thermoFluid_d_ptr->thermo_ptr_d.size;
    size_t n_reactions = reactions_d_ptr->reactions_ptr_d.react_num;

    set_block_grid2d(n_cells, n_reactions, block_set_J02d, griddim, blockdim);

    compute_react_c_g<<<griddim, blockdim, 0, Stream_opencc[0]>>>(
        tb_coeffs, c, T, kf, kr, kf_low, react_c,
        react_type, fall_type, fall_coeffs, vr, vf, is_rev, q,
        dt_sum_d, t_end_h);
}

// Assembles the time derivative of the state vector for each cell.
// Launch: 1-D over cells; one thread per cell.
//
//   dcdt[k]          = sum_i v_net[i + react_num_d*k] * q[i]   (species k)
//   dcdt[sp_num_d]   = -sum_k ha_k * dcdt[k] / sum_k cp_k * c_k   (temperature)
//   dcdt[sp_num_d+1] = 0                                           (pressure)
//
// ha_k comes from sp_ha_mole and cp_k from sp_cp_mole (presumably molar
// enthalpy and heat capacity — confirm units with the caller).
// sp_net_rate is accepted but not used.
//
// Fix: the cp accumulator was previously declared without an initializer and
// then incremented, which is undefined behavior and made dT/dt garbage.
//
// Cells with access_data(dt_sum_d, s) > t_end_h are skipped.
__global__ void compute_dTdt_g(REAL *v_net, REAL *q, REAL *sp_ha_mole, REAL *sp_cp_mole, REAL *sp_net_rate, REAL *c, REAL *dcdt, REAL *dt_sum_d, REAL t_end_h) {
    unsigned int s = blockDim.x * blockIdx.x + threadIdx.x;

    if (s < size_d && access_data(dt_sum_d, s) <= t_end_h) {
        REAL heat_release  = 0.;   // sum_k ha_k * dc_k/dt
        REAL heat_capacity = 0.;   // sum_k cp_k * c_k

        for (int k = 0; k < sp_num_d; k++) {
            REAL wdot = 0.;
            for (int i = 0; i < react_num_d; i++) {
                wdot += *(v_net + i + react_num_d*k) * access_react_num_data(q, i, s);
            }
            heat_release  += access_sp_num_data(sp_ha_mole, k, s) * wdot;
            heat_capacity += access_sp_num_data(sp_cp_mole, k, s) * access_vars_data(c, k, s);

            access_vars_data(dcdt, k, s) = wdot;
        }

        access_vars_data(dcdt, sp_num_d, s) = -heat_release / heat_capacity;
        access_vars_data(dcdt, sp_num_d+1, s) = 0.;
    }
}

// Host launcher for compute_dTdt_g.
// Uses a 1-D launch over all cells, enqueued asynchronously on Stream_opencc[0].
// The previously unused react_num local (and its reactions_d_ptr
// dereference) has been removed; the kernel reads react_num_d itself.
__host__ void compute_dTdt_h(REAL *v_net, REAL *q, REAL *sp_ha_mole, REAL *sp_cp_mole, REAL *sp_net_rate, REAL *c, REAL *dcdt) {
    size_t size = thermoFluid_d_ptr->thermo_ptr_d.size;

    dim3 griddim, blockdim;

    set_block_grid(size, block_set_J0, griddim, blockdim);

    compute_dTdt_g<<<griddim, blockdim, 0, Stream_opencc[0]>>>(v_net, q, sp_ha_mole, sp_cp_mole, sp_net_rate, c, dcdt, dt_sum_d, t_end_h);
}

