/*******************************************************************************************
 * This file contains the parameters and function implementation related to reactions.
 ******************************************************************************************/

#include "thermoFluid.h"
#include "species.h"
#include "seulex.h"

#include "reactions.h"
#include "reactions_kernel.h"

/** Global pointer to the reactions_d object. It follows the same pattern as the
 *  thermoFluid_d_ptr / species_d_ptr globals that this file dereferences.
 *  It is only declared here; no code in this file assigns it (presumably set by whoever
 *  constructs reactions_d — confirm).
 */
reactions_d *reactions_d_ptr;

/** Construct class reactions_d.
 *
 *  Only the reaction count is taken from the supplied descriptor. The array members of
 *  reactions_ptr_d are filled later through reactions_d_set().
 *
 *  \param[in] reactions_ptr_d_ref  Descriptor whose react_num field is copied; must not be null.
 */
reactions_d::reactions_d(reactions_ptr *reactions_ptr_d_ref) : jacobian_d() {
    const auto &source_desc = *reactions_ptr_d_ref;
    reactions_ptr_d.react_num = source_desc.react_num;
}

/** Destroy class reactions_d.
 *
 *  Every buffer held in reactions_ptr_d is passed to hipFree(), except entries that are
 *  still nullptr. Buffers are released in the order they appear in the table below.
 */
reactions_d::~reactions_d() {

    void *const owned_buffers[] = {
        reactions_ptr_d.vf,
        reactions_ptr_d.vr,
        reactions_ptr_d.v_net,
        reactions_ptr_d.react_type,
        reactions_ptr_d.is_rev,
        reactions_ptr_d.abe,
        reactions_ptr_d.tb_coeffs,
        reactions_ptr_d.fall_type,
        reactions_ptr_d.fall_coeffs,
        reactions_ptr_d.order,
        reactions_ptr_d.kf,
        reactions_ptr_d.kf_low,
        reactions_ptr_d.kr,
        reactions_ptr_d.Rf,
        reactions_ptr_d.Rr,
        reactions_ptr_d.R_net,
        reactions_ptr_d.react_c,
        reactions_ptr_d.q,
        reactions_ptr_d.sp_net_rate,
        reactions_ptr_d.dcdt,
    };

    for (void *buffer : owned_buffers) {
        if (buffer != nullptr) CUDACHECK(hipFree(buffer))
    }
}

/** Set class reactions_d.
 *
 *  Steps:
 *   - records react_num in reactions_ptr_d;
 *   - calls jacobian_init() and seulex_d_ptr->seulex_init();
 *   - copies every array selected by `op` from the host-side package into
 *     reactions_ptr_d via DeviceDataset();
 *   - publishes the reaction count to the device symbols react_num_d and
 *     pitch_react_num_d.
 *
 *  The thermo scope (thermoFluid_d_ptr), the species scope (species_d_ptr) and the
 *  SEULEX solver (seulex_d_ptr) must already exist. If any of them is missing, an error
 *  is printed and the function returns without doing anything else.
 *
 *  \param[in] react_num          chemical reaction number. (host)
 *  \param[in] reactions_ptr_ref  The memory address of the host side corresponding to
 *                                the packaged storage reactions parameters. (host)
 *  \param[in] op                 Data processing mode, see REACTION_SET_MODE for more. (host)
 */
void reactions_d::reactions_d_set(int react_num, reactions_ptr *reactions_ptr_ref, REACTION_SET_MODE op) {

    // Every pointer below is dereferenced, so bail out instead of crashing when a
    // prerequisite scope has not been created yet.
    if (thermoFluid_d_ptr == nullptr) {
        MPI_PRINTF("\033[31mBEFORE SETTING UP THE REACTION CLASS, PLEASE CREATE THE THERMO SCOPE.\033[0m\n");
        return;
    }
    if (species_d_ptr == nullptr) {
        MPI_PRINTF("\033[31mBEFORE SETTING UP THE REACTION CLASS, PLEASE CREATE THE SPECIES SCOPE.\033[0m\n");
        return;
    }
    if (seulex_d_ptr == nullptr) {
        MPI_PRINTF("\033[31mBEFORE SETTING UP THE REACTION CLASS, PLEASE CREATE THE SEULEX SOLVER.\033[0m\n");
        return;
    }

    const size_t size   = thermoFluid_d_ptr->thermo_ptr_d.size;
    const size_t sp_num = species_d_ptr->species_const_d.sp_num;
    reactions_ptr_d.react_num = react_num;

    jacobian_init();
    seulex_d_ptr->seulex_init();

    // A field is processed when it is requested explicitly or everything is requested.
    auto selected = [op](REACTION_SET_MODE mode) {
        return op == mode || op == REACTIONS_ALL;
    };

    // Per-species, per-reaction tables.
    if (selected(vf_MODE))
        DeviceDataset(reactions_ptr_d.vf         , reactions_ptr_ref->vf         , sizeof(REAL)*sp_num*react_num);
    if (selected(vr_MODE))
        DeviceDataset(reactions_ptr_d.vr         , reactions_ptr_ref->vr         , sizeof(REAL)*sp_num*react_num);
    if (selected(v_net_MODE))
        DeviceDataset(reactions_ptr_d.v_net      , reactions_ptr_ref->v_net      , sizeof(REAL)*sp_num*react_num);
    if (selected(tb_coeffs_MODE))
        DeviceDataset(reactions_ptr_d.tb_coeffs  , reactions_ptr_ref->tb_coeffs  , sizeof(REAL)*sp_num*react_num);

    // Per-reaction scalars.
    if (selected(order_MODE))
        DeviceDataset(reactions_ptr_d.order      , reactions_ptr_ref->order      , sizeof(REAL)*react_num);
    if (selected(react_type_MODE))
        DeviceDataset(reactions_ptr_d.react_type , reactions_ptr_ref->react_type , sizeof(REAL)*react_num);
    if (selected(is_rev_MODE))
        DeviceDataset(reactions_ptr_d.is_rev     , reactions_ptr_ref->is_rev     , sizeof(REAL)*react_num);
    if (selected(fall_type_MODE))
        DeviceDataset(reactions_ptr_d.fall_type  , reactions_ptr_ref->fall_type  , sizeof(REAL)*react_num);

    // Per-reaction coefficient sets (6 and 5 entries per reaction respectively).
    if (selected(abe_MODE))
        DeviceDataset(reactions_ptr_d.abe        , reactions_ptr_ref->abe        , sizeof(REAL)*react_num*6);
    if (selected(fall_coeffs_MODE))
        DeviceDataset(reactions_ptr_d.fall_coeffs, reactions_ptr_ref->fall_coeffs, sizeof(REAL)*react_num*5);

    //--------------------------------------------------------------------------------------
    // Per-reaction, per-cell work arrays.
    if (selected(kf_MODE))
        DeviceDataset(reactions_ptr_d.kf         , reactions_ptr_ref->kf         , sizeof(REAL)*react_num*size);
    if (selected(kf_low_MODE))
        DeviceDataset(reactions_ptr_d.kf_low     , reactions_ptr_ref->kf_low     , sizeof(REAL)*size*react_num);
    if (selected(kr_MODE))
        DeviceDataset(reactions_ptr_d.kr         , reactions_ptr_ref->kr         , sizeof(REAL)*react_num*size);
    if (selected(Rf_MODE))
        DeviceDataset(reactions_ptr_d.Rf         , reactions_ptr_ref->Rf         , sizeof(REAL)*size*react_num);
    if (selected(Rr_MODE))
        DeviceDataset(reactions_ptr_d.Rr         , reactions_ptr_ref->Rr         , sizeof(REAL)*size*react_num);
    if (selected(R_net_MODE))
        DeviceDataset(reactions_ptr_d.R_net      , reactions_ptr_ref->R_net      , sizeof(REAL)*size*react_num);
    if (selected(react_c_MODE))
        DeviceDataset(reactions_ptr_d.react_c    , reactions_ptr_ref->react_c    , sizeof(REAL)*react_num*size);
    if (selected(q_MODE))
        DeviceDataset(reactions_ptr_d.q          , reactions_ptr_ref->q          , sizeof(REAL)*size*react_num);

    // Per-species, per-cell arrays.
    if (selected(sp_net_rate_MODE))
        DeviceDataset(reactions_ptr_d.sp_net_rate, reactions_ptr_ref->sp_net_rate, sizeof(REAL)*size*sp_num);
    if (selected(dcdt_MODE))
        DeviceDataset(reactions_ptr_d.dcdt       , reactions_ptr_ref->dcdt       , sizeof(REAL)*size*(sp_num+2));

    // No pitch value is produced by the allocations above. pitch_react_num_d is published
    // as the plain reaction count, so the host-side copy mirrors that same value instead
    // of reading an uninitialized variable.
    const size_t react_num_sz = static_cast<size_t>(react_num);
    pitch_react_num = react_num_sz;

    // HIP_SYMBOL must wrap the symbol exactly once to stay portable across HIP backends.
    CUDACHECK(hipMemcpyToSymbol(HIP_SYMBOL(react_num_d), &react_num_sz, sizeof(size_t), 0, hipMemcpyHostToDevice));
    CUDACHECK(hipMemcpyToSymbol(HIP_SYMBOL(pitch_react_num_d), &react_num_sz, sizeof(size_t), 0, hipMemcpyHostToDevice));
}

/** Query class reactions_d.
 *
 *  Copies every array selected by `op` from reactions_ptr_d into the host-side package
 *  using DeviceDataget(). The byte counts match those used by reactions_d_set().
 *
 *  The thermo scope (thermoFluid_d_ptr) and the species scope (species_d_ptr) must
 *  already exist. Otherwise an error is printed and nothing is copied.
 *
 *  \param[in] reactions_ptr_ref  The memory address of the host side corresponding to
 *                                the packaged storage reactions parameters. (host)
 *  \param[in] op                 Data processing mode, see REACTION_SET_MODE for more. (host)
 */
void reactions_d::reactions_d_get(reactions_ptr *reactions_ptr_ref, REACTION_SET_MODE op) {

    // Both scopes are dereferenced below; return instead of crashing when one is missing.
    if (thermoFluid_d_ptr == nullptr) {
        MPI_PRINTF("\033[31mBEFORE GETTING THE REACTION PARAMETERS, PLEASE CREATE THE THERMO SCOPE.\033[0m\n");
        return;
    }
    if (species_d_ptr == nullptr) {
        MPI_PRINTF("\033[31mBEFORE GETTING THE REACTION PARAMETERS, PLEASE CREATE THE SPECIES SCOPE.\033[0m\n");
        return;
    }

    const size_t size      = thermoFluid_d_ptr->thermo_ptr_d.size;
    const size_t sp_num    = species_d_ptr->species_const_d.sp_num;
    const size_t react_num = reactions_ptr_d.react_num;

    // A field is processed when it is requested explicitly or everything is requested.
    auto selected = [op](REACTION_SET_MODE mode) {
        return op == mode || op == REACTIONS_ALL;
    };

    // vf and vr are read back from the device copy like every other field:
    // reactions_d_set() only fills reactions_ptr_d, never reactions_ptr_h.
    if (selected(vf_MODE))
        DeviceDataget(reactions_ptr_d.vf         , reactions_ptr_ref->vf         , sizeof(REAL)*sp_num*react_num);
    if (selected(vr_MODE))
        DeviceDataget(reactions_ptr_d.vr         , reactions_ptr_ref->vr         , sizeof(REAL)*sp_num*react_num);
    if (selected(v_net_MODE))
        DeviceDataget(reactions_ptr_d.v_net      , reactions_ptr_ref->v_net      , sizeof(REAL)*sp_num*react_num);
    if (selected(tb_coeffs_MODE))
        DeviceDataget(reactions_ptr_d.tb_coeffs  , reactions_ptr_ref->tb_coeffs  , sizeof(REAL)*sp_num*react_num);

    if (selected(order_MODE))
        DeviceDataget(reactions_ptr_d.order      , reactions_ptr_ref->order      , sizeof(REAL)*react_num);

    if (selected(react_type_MODE))
        DeviceDataget(reactions_ptr_d.react_type , reactions_ptr_ref->react_type , sizeof(REAL)*react_num);
    if (selected(is_rev_MODE))
        DeviceDataget(reactions_ptr_d.is_rev     , reactions_ptr_ref->is_rev     , sizeof(REAL)*react_num);
    if (selected(fall_type_MODE))
        DeviceDataget(reactions_ptr_d.fall_type  , reactions_ptr_ref->fall_type  , sizeof(REAL)*react_num);

    if (selected(abe_MODE))
        DeviceDataget(reactions_ptr_d.abe        , reactions_ptr_ref->abe        , sizeof(REAL)*react_num*6);
    if (selected(fall_coeffs_MODE))
        DeviceDataget(reactions_ptr_d.fall_coeffs, reactions_ptr_ref->fall_coeffs, sizeof(REAL)*react_num*5);

    //--------------------------------------------------------------------------------------
    if (selected(kf_MODE))
        DeviceDataget(reactions_ptr_d.kf         , reactions_ptr_ref->kf         , size*sizeof(REAL)*react_num);
    if (selected(kf_low_MODE))
        DeviceDataget(reactions_ptr_d.kf_low     , reactions_ptr_ref->kf_low     , size*sizeof(REAL)*react_num);
    if (selected(kr_MODE))
        DeviceDataget(reactions_ptr_d.kr         , reactions_ptr_ref->kr         , size*sizeof(REAL)*react_num);
    if (selected(Rf_MODE))
        DeviceDataget(reactions_ptr_d.Rf         , reactions_ptr_ref->Rf         , sizeof(REAL)*size*react_num);
    if (selected(Rr_MODE))
        DeviceDataget(reactions_ptr_d.Rr         , reactions_ptr_ref->Rr         , sizeof(REAL)*size*react_num);
    if (selected(R_net_MODE))
        DeviceDataget(reactions_ptr_d.R_net      , reactions_ptr_ref->R_net      , sizeof(REAL)*size*react_num);
    if (selected(react_c_MODE))
        DeviceDataget(reactions_ptr_d.react_c    , reactions_ptr_ref->react_c    , size*sizeof(REAL)*react_num);
    if (selected(q_MODE))
        DeviceDataget(reactions_ptr_d.q          , reactions_ptr_ref->q          , sizeof(REAL)*size*react_num);
    if (selected(sp_net_rate_MODE))
        DeviceDataget(reactions_ptr_d.sp_net_rate, reactions_ptr_ref->sp_net_rate, sizeof(REAL)*size*sp_num);
    if (selected(dcdt_MODE))
        DeviceDataget(reactions_ptr_d.dcdt       , reactions_ptr_ref->dcdt       , size*sizeof(REAL)*(sp_num+2));
}

/** Evaluate the reaction source terms for state vector y and write them into dy.
 *
 *  The host-side wrappers below run in a fixed sequence. Later stages receive buffers
 *  that earlier stages appear to fill. For example, reactions_ptr_d.sp_net_rate is the
 *  last argument of compute_B_h and is then passed into compute_kf_kr_h. Keep the order
 *  unchanged.
 *
 *  \param[in]  y   State vector; its layout is defined by compute_T_P_h / compute_dTdt_h
 *                  (not visible here — confirm).
 *  \param[out] dy  Derivative buffer, passed as the last argument of compute_dTdt_h.
 */
void reactions_d::cal_react_rate(REAL *y, REAL *dy) {

    auto &thermo = thermoFluid_d_ptr->thermo_ptr_d;
    auto &consts = species_d_ptr->species_const_d;
    auto &rx     = reactions_ptr_d;

    // Temperature / pressure fields derived from y (per the wrapper's name — confirm).
    compute_T_P_h(y, thermo.T, thermo.P);

    // Stage fed by the species NASA data. Its output lands in rx.sp_net_rate, which seems to
    // be used as scratch before the final stage (NOTE(review): confirm).
    compute_B_h(thermo.T, consts.T_range, consts.sp_nasa, rx.sp_net_rate);

    // Rate-coefficient stage: kf, kr and kf_low are passed as destinations.
    compute_kf_kr_h(rx.abe, thermo.T, rx.v_net, rx.is_rev, rx.sp_net_rate,
                    rx.kf, rx.kr, rx.kf_low);

    // Reaction-progress stage using third-body and fall-off data; q is passed last.
    compute_react_c_h(rx.tb_coeffs, y, thermo.T, rx.kf, rx.kr, rx.kf_low, rx.react_c,
                      rx.react_type, rx.fall_type, rx.fall_coeffs, rx.vr, rx.vf,
                      rx.is_rev, rx.q);

    // Refresh the species enthalpy / heat-capacity arrays read by the final stage.
    thermoFluid_d_ptr->thermoFluid_compute_ha_cp();

    compute_dTdt_h(rx.v_net, rx.q, thermo.sp_ha_mole, thermo.sp_cp_mole,
                   rx.sp_net_rate, y, dy);
}