/*******************************************************************************************
 * This file contains the parameters and function implementation related to thermophysical
 * properties.
 ******************************************************************************************/

#include "species.h"
#include "thermoFluid.h"
#include "thermoFluid_kernel.h"

//__constant__ size_t a_d;

/** Global pointer to a thermoFluid_d object.
 *
 *  It is null until some other code assigns it; no assignment is visible in
 *  this part of the file.
 */
thermoFluid_d *thermoFluid_d_ptr = nullptr;

/** Construct a thermoFluid_d from an already populated set of pointers.
 *
 *  Every pointer field of \p thermo_ptr_d_ref is copied as-is into the member
 *  thermo_ptr_d (shallow copy: nothing is allocated and no data is moved).
 *  The sp_num and size fields of thermo_ptr_d are not touched here.
 *
 *  NOTE(review): the destructor passes these same pointers to hipFree, so the
 *  object appears to take over ownership of them -- confirm with the caller.
 *
 *  \param[in] thermo_ptr_d_ref  Pointer to device memory.
 */
thermoFluid_d::thermoFluid_d(thermo_ptr *thermo_ptr_d_ref) {

    const thermo_ptr &src = *thermo_ptr_d_ref;
    auto             &dst = thermo_ptr_d;

    // Fields the original comment marks as "data from CFD".
    dst.T = src.T;
    dst.P = src.P;
    dst.Y = src.Y;
    dst.X = src.X;

    // Per-point mixture quantities.
    dst.rho     = src.rho;
    dst.W_mix   = src.W_mix;
    dst.ha_mass = src.ha_mass;
    dst.ha_mole = src.ha_mole;
    dst.hc_mass = src.hc_mass;
    dst.hc_mole = src.hc_mole;
    dst.hs_mass = src.hs_mass;
    dst.hs_mole = src.hs_mole;
    dst.cp_mole = src.cp_mole;
    dst.psi     = src.psi;

    // Remaining arrays (c and the sp_* fields).
    dst.c          = src.c;
    dst.sp_ha_mole = src.sp_ha_mole;
    dst.sp_cp_mole = src.sp_cp_mole;
}

/** Destroy a thermoFluid_d object.
 *
 *  Each non-null pointer stored in thermo_ptr_d is passed to hipFree, with the
 *  call wrapped in CUDACHECK.
 *
 *  NOTE(review): the buffers that thermoFluid_d_set creates for THERMO_ALL
 *  (dt_sum_d, real_index, real_num, real_num_total) are not released here --
 *  confirm whether they are freed somewhere else.
 */
thermoFluid_d::~thermoFluid_d() {

    // Fields the original comment marks as "data from CFD".
    if (thermo_ptr_d.T) { CUDACHECK(hipFree(thermo_ptr_d.T)); }
    if (thermo_ptr_d.P) { CUDACHECK(hipFree(thermo_ptr_d.P)); }
    if (thermo_ptr_d.Y) { CUDACHECK(hipFree(thermo_ptr_d.Y)); }
    if (thermo_ptr_d.X) { CUDACHECK(hipFree(thermo_ptr_d.X)); }

    // Per-point mixture quantities.
    if (thermo_ptr_d.rho)     { CUDACHECK(hipFree(thermo_ptr_d.rho)); }
    if (thermo_ptr_d.W_mix)   { CUDACHECK(hipFree(thermo_ptr_d.W_mix)); }
    if (thermo_ptr_d.ha_mass) { CUDACHECK(hipFree(thermo_ptr_d.ha_mass)); }
    if (thermo_ptr_d.ha_mole) { CUDACHECK(hipFree(thermo_ptr_d.ha_mole)); }
    if (thermo_ptr_d.hc_mass) { CUDACHECK(hipFree(thermo_ptr_d.hc_mass)); }
    if (thermo_ptr_d.hc_mole) { CUDACHECK(hipFree(thermo_ptr_d.hc_mole)); }
    if (thermo_ptr_d.hs_mass) { CUDACHECK(hipFree(thermo_ptr_d.hs_mass)); }
    if (thermo_ptr_d.hs_mole) { CUDACHECK(hipFree(thermo_ptr_d.hs_mole)); }
    if (thermo_ptr_d.cp_mole) { CUDACHECK(hipFree(thermo_ptr_d.cp_mole)); }
    if (thermo_ptr_d.psi)     { CUDACHECK(hipFree(thermo_ptr_d.psi)); }

    // Remaining arrays (c and the sp_* fields).
    if (thermo_ptr_d.c)          { CUDACHECK(hipFree(thermo_ptr_d.c)); }
    if (thermo_ptr_d.sp_ha_mole) { CUDACHECK(hipFree(thermo_ptr_d.sp_ha_mole)); }
    if (thermo_ptr_d.sp_cp_mole) { CUDACHECK(hipFree(thermo_ptr_d.sp_cp_mole)); }
}

/** Set class thermoFluid_d.
 *  
 *  \param[in] sp_num          The number of species. (host)
 *  \param[in] size            Number of grid points. (host)
 *  \param[in] thermo_ptr_ref  The memory address of the host side corresponding to 
 *                             the packaged storage thermo parameters. (host)
 *  \param[in] op              Data processing mode, see THERMO_SET_MODE for more. (host) 
 */
void thermoFluid_d::thermoFluid_d_set(int sp_num, int size, thermo_ptr *thermo_ptr_ref, THERMO_SET_MODE op) {

    size_t tmp;

    if (op == T_MODE          || op == THERMO_ALL) {
        DeviceDataset(thermo_ptr_d.T,          thermo_ptr_ref->T,          size*sizeof(REAL));;
    }
    if (op == P_MODE          || op == THERMO_ALL) {
        DeviceDataset(thermo_ptr_d.P,          thermo_ptr_ref->P,          size*sizeof(REAL));
    }
    if (op == Y_MODE          || op == THERMO_ALL) {
        DeviceDataset(thermo_ptr_d.Y,        thermo_ptr_ref->Y,          size*sizeof(REAL)*sp_num);

#ifdef DEBUG
        printf("sp_num is %d, size is %d\n", sp_num, size);
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < sp_num; j++) {
                 MPI_PRINTF("Y[%d][%d]: %e  ", i, j, thermo_ptr_h.Y[j + i*sp_num]);
            }
            MPI_PRINTF("\n");
        }
#endif //DEBUG
    
    }
    if (op == X_MODE          || op == THERMO_ALL) {
        DeviceDataset(thermo_ptr_d.X,          thermo_ptr_ref->X,          size*sizeof(REAL)*sp_num);
    }
        //---------------------- Above data from CFD

    if (op == rho_MODE        || op == THERMO_ALL) {
        DeviceDataset(thermo_ptr_d.rho,        thermo_ptr_ref->rho,        size*sizeof(REAL));
    }
    if (op == W_mix_MODE      || op == THERMO_ALL) {
        DeviceDataset(thermo_ptr_d.W_mix,      thermo_ptr_ref->W_mix,      size*sizeof(REAL));
    }
    if (op == ha_mass_MODE    || op == THERMO_ALL) {
        DeviceDataset(thermo_ptr_d.ha_mass,    thermo_ptr_ref->ha_mass,    size*sizeof(REAL));
    }
    if (op == ha_mole_MODE    || op == THERMO_ALL) {
        DeviceDataset(thermo_ptr_d.ha_mole,    thermo_ptr_ref->ha_mole,    size*sizeof(REAL));
    }
    if (op == hc_mass_MODE    || op == THERMO_ALL) {
        DeviceDataset(thermo_ptr_d.hc_mass,    thermo_ptr_ref->hc_mass,    sp_num*sizeof(REAL));
    }
    if (op == hc_mole_MODE    || op == THERMO_ALL) {
        DeviceDataset(thermo_ptr_d.hc_mole,    thermo_ptr_ref->hc_mole,    size*sizeof(REAL));
    }
    if (op == hs_mass_MODE    || op == THERMO_ALL) {
        DeviceDataset(thermo_ptr_d.hs_mass,    thermo_ptr_ref->hs_mass,    size*sizeof(REAL));
    }
    if (op == hs_mole_MODE    || op == THERMO_ALL) {
        DeviceDataset(thermo_ptr_d.hs_mole,    thermo_ptr_ref->hs_mole,    size*sizeof(REAL));
    }
    if (op == cp_mole_MODE    || op == THERMO_ALL) {
        DeviceDataset(thermo_ptr_d.cp_mole,    thermo_ptr_ref->cp_mole,    size*sizeof(REAL));
    }
    if (op == psi_MODE        || op == THERMO_ALL) {
        DeviceDataset(thermo_ptr_d.psi,        thermo_ptr_ref->psi,        size*sizeof(REAL));
    }

    if (op == c_MODE          || op == THERMO_ALL) {
        DeviceDataset(thermo_ptr_d.c,          thermo_ptr_ref->c,          size*sizeof(REAL)*(sp_num+2));
    }
    if (op == sp_ha_mole_MODE || op == THERMO_ALL) {
        DeviceDataset(thermo_ptr_d.sp_ha_mole, thermo_ptr_ref->sp_ha_mole, size*sizeof(REAL)*sp_num);
    }
    if (op == sp_cp_mole_MODE || op == THERMO_ALL) {
        DeviceDataset(thermo_ptr_d.sp_cp_mole, thermo_ptr_ref->sp_cp_mole, size*sizeof(REAL)*sp_num);
    }

    if (op == THERMO_ALL) {
        thermo_ptr_d.sp_num = sp_num;
        thermo_ptr_d.size   = size;

        tmp = sp_num; CUDACHECK(hipMemcpyToSymbol(HIP_SYMBOL(HIP_SYMBOL(sp_num_d)), &tmp, sizeof(size_t)));
                      CUDACHECK(hipMemcpyToSymbol(HIP_SYMBOL(HIP_SYMBOL(pitch_sp_num_d)), &tmp, sizeof(size_t)));
        tmp = size;   CUDACHECK(hipMemcpyToSymbol(HIP_SYMBOL(HIP_SYMBOL(size_d)), &tmp, sizeof(size_t)));

        CUDACHECK(hipMalloc((void**)&dt_sum_d, sizeof(REAL)*size));

        cuda_mem_value_init(0.0, dt_sum_d, 1, 1, size, block_set_J0);

        CUDACHECK(hipMalloc((void**)&dt_sum_d, sizeof(REAL)*size));
        CUDACHECK(hipMalloc((void**)&real_index, sizeof(int)*size));

        dim3 griddim, blockdim;
        set_block_grid(size, block_set_real, griddim, blockdim);

        CUDACHECK(hipMalloc((void**)&real_num, sizeof(int)*griddim.x));

        CUDACHECK(hipMalloc((void**)&real_num_total, sizeof(int)));

        DeviceDataset(T_origin,        thermo_ptr_ref->T,          size*sizeof(REAL));
        DeviceDataset(P_origin,        thermo_ptr_ref->P,          size*sizeof(REAL));
        DeviceDataset(Y_origin,        thermo_ptr_ref->Y,          size*sizeof(REAL)*sp_num);

        size_origin = size;
    }
}

/** Load new T/P/Y data into the *_origin buffers and rebuild the index of "real" points.
 *
 *  Steps, in order:
 *    1. Copy \p T, \p P and \p Y into T_origin, P_origin and Y_origin with
 *       hipMemcpyAsync on Stream_opencc[0]; \p op selects host-to-device (CPU)
 *       or device-to-device (GPU). Any other value only prints an error.
 *    2. Call get_id_index with the origin buffers, the thermo_ptr_d T/P/Y
 *       buffers and real_index / real_num / real_num_total.
 *    3. Read the int stored at real_num_total back to the host, store it in
 *       thermo_ptr_d.size and write it to the device symbol size_d.
 *
 *  NOTE(review): steps 2 and 3 also run when \p op is invalid, and the async
 *  copies are not explicitly synchronised before get_id_index -- confirm both
 *  are intended.
 *
 *  \param[in] T       Temperature array, \p size elements.
 *  \param[in] P       Pressure array, \p size elements.
 *  \param[in] Y       Mass-fraction array, \p size * \p sp_num elements.
 *  \param[in] sp_num  The number of species. (host)
 *  \param[in] size    Number of grid points. (host)
 *  \param[in] op      Where the input arrays live (CPU or GPU).
 */
void thermoFluid_d::set_TPY(REAL *T, REAL *P, REAL *Y, int sp_num, int size, DATA_MODE op) {

    // Widen before multiplying so the byte counts are computed in size_t.
    const size_t scalar_bytes  = static_cast<size_t>(size) * sizeof(REAL);
    const size_t species_bytes = scalar_bytes * sp_num;

    switch(op) {
        case CPU:
        CUDACHECK(hipMemcpyAsync(T_origin, T, scalar_bytes,  hipMemcpyHostToDevice, Stream_opencc[0]));
        CUDACHECK(hipMemcpyAsync(P_origin, P, scalar_bytes,  hipMemcpyHostToDevice, Stream_opencc[0]));
        CUDACHECK(hipMemcpyAsync(Y_origin, Y, species_bytes, hipMemcpyHostToDevice, Stream_opencc[0]));
        break;

        case GPU:
        CUDACHECK(hipMemcpyAsync(T_origin, T, scalar_bytes,  hipMemcpyDeviceToDevice, Stream_opencc[0]));
        CUDACHECK(hipMemcpyAsync(P_origin, P, scalar_bytes,  hipMemcpyDeviceToDevice, Stream_opencc[0]));
        CUDACHECK(hipMemcpyAsync(Y_origin, Y, species_bytes, hipMemcpyDeviceToDevice, Stream_opencc[0]));
        break;

        default:
        MPI_PRINTF("\033[31mWORRY!!! WHEN SETTING DATA_MODE, DATA_MODE SETTING ERROR.\033[0m\n");
        break;
    }

    get_id_index(size, sp_num, thermo_ptr_d.T, T_origin, thermo_ptr_d.P, P_origin, thermo_ptr_d.Y, Y_origin, real_index, real_num, real_num_total);

    // real_num_total holds a single int on the device. Read it into a local
    // int first so exactly sizeof(int) bytes land in an int, whatever the
    // declared width of thermo_ptr_d.size is.
    int real_count = 0;
    CUDACHECK(hipMemcpy(&real_count, real_num_total, sizeof(int), hipMemcpyDeviceToHost));
    thermo_ptr_d.size = real_count;

    // The device symbol is written as size_t.
    const size_t size_sym = thermo_ptr_d.size;
    CUDACHECK(hipMemcpyToSymbol(HIP_SYMBOL(size_d), &size_sym, sizeof(size_t)));
}

/** Restore the original problem size and hand the mass fractions back to the caller.
 *
 *  Steps, in order:
 *    1. Write size_origin to the device symbol size_d and to thermo_ptr_d.size.
 *    2. Call reconstructY with thermo_ptr_d.Y, Y_origin, real_index and real_num.
 *    3. Deliver Y_origin to \p Y: for CPU via hipMemcpyAsync (device-to-host)
 *       on Stream_opencc[0], for GPU via cuda_copy. Any other \p op only
 *       prints an error.
 *
 *  NOTE(review): the CPU path issues an asynchronous copy and returns without
 *  synchronising the stream here -- the caller presumably synchronises before
 *  reading \p Y; confirm.
 *  NOTE(review): the GPU path relies on cuda_copy taking (source, destination)
 *  in that order -- confirm against its declaration.
 *
 *  \param[out] Y       Destination for \p size * \p sp_num mass fractions.
 *  \param[in]  sp_num  The number of species. (host)
 *  \param[in]  size    Number of grid points. (host)
 *  \param[in]  op      Where \p Y lives (CPU or GPU).
 */
void thermoFluid_d::get_Y(REAL *Y, int sp_num, int size, DATA_MODE op) {

    // The device symbol is written as size_t.
    const size_t full_size = size_origin;
    CUDACHECK(hipMemcpyToSymbol(HIP_SYMBOL(size_d), &full_size, sizeof(size_t)));

    thermo_ptr_d.size = full_size;

    reconstructY(size, sp_num, thermo_ptr_d.Y, Y_origin, real_index, real_num);

    switch(op) {
        case CPU:
        // Widen before multiplying so the byte count is computed in size_t.
        CUDACHECK(hipMemcpyAsync(Y, Y_origin, static_cast<size_t>(size)*sp_num*sizeof(REAL), hipMemcpyDeviceToHost, Stream_opencc[0]));
        break;

        case GPU:
        cuda_copy(Y_origin, Y, sp_num, sp_num, size, block_set_J0);
        break;

        default:
        MPI_PRINTF("\033[31mWORRY!!! WHEN SETTING DATA_MODE, DATA_MODE SETTING ERROR.\033[0m\n");
        break;
    }
}

/** Query the thermo_ptr_d.
 *
 *  For the field selected by \p op (or for every field when \p op is
 *  THERMO_ALL) the device pointer in thermo_ptr_d and the matching array in
 *  \p thermo_ptr_ref are handed to DeviceDataget together with a byte count.
 *  The counts use the sp_num and size currently stored in thermo_ptr_d:
 *    - T, P, rho, W_mix, ha_mass, ha_mole, hc_mole, hs_mass, hs_mole,
 *      cp_mole, psi                     : size * sizeof(REAL)
 *    - hc_mass                          : sp_num * sizeof(REAL)
 *    - Y, X, sp_ha_mole, sp_cp_mole     : size * sp_num * sizeof(REAL)
 *    - c                                : size * (sp_num + 2) * sizeof(REAL)
 *
 *  \param[in,out] thermo_ptr_ref  The memory address of the host side corresponding to
 *                                 the packaged storage thermo parameters; presumably the
 *                                 copy destination -- confirm against DeviceDataget. (host)
 *  \param[in]     op              Data processing mode, see THERMO_SET_MODE for more. (host)
 */
void thermoFluid_d::thermoFluid_d_get(thermo_ptr *thermo_ptr_ref, THERMO_SET_MODE op) {

    const size_t sp_num = thermo_ptr_d.sp_num;
    const size_t size   = thermo_ptr_d.size;
    const bool   all    = (op == THERMO_ALL);

    // Every call is braced so the guard covers the whole statement even if
    // DeviceDataget is a multi-statement macro.
    if (all || op == T_MODE) {
        DeviceDataget(thermo_ptr_d.T,          thermo_ptr_ref->T,          size*sizeof(REAL));
    }
    if (all || op == P_MODE) {
        DeviceDataget(thermo_ptr_d.P,          thermo_ptr_ref->P,          size*sizeof(REAL));
    }
    if (all || op == Y_MODE) {
        DeviceDataget(thermo_ptr_d.Y,          thermo_ptr_ref->Y,          size*sizeof(REAL)*sp_num);
    }
    if (all || op == X_MODE) {
        DeviceDataget(thermo_ptr_d.X,          thermo_ptr_ref->X,          size*sizeof(REAL)*sp_num);
    }

    if (all || op == rho_MODE) {
        DeviceDataget(thermo_ptr_d.rho,        thermo_ptr_ref->rho,        size*sizeof(REAL));
    }
    if (all || op == W_mix_MODE) {
        DeviceDataget(thermo_ptr_d.W_mix,      thermo_ptr_ref->W_mix,      size*sizeof(REAL));
    }
    if (all || op == ha_mass_MODE) {
        DeviceDataget(thermo_ptr_d.ha_mass,    thermo_ptr_ref->ha_mass,    size*sizeof(REAL));
    }
    if (all || op == ha_mole_MODE) {
        DeviceDataget(thermo_ptr_d.ha_mole,    thermo_ptr_ref->ha_mole,    size*sizeof(REAL));
    }
    if (all || op == hc_mass_MODE) {
        // hc_mass is sized by species count, not by grid points.
        DeviceDataget(thermo_ptr_d.hc_mass,    thermo_ptr_ref->hc_mass,    sp_num*sizeof(REAL));
    }
    if (all || op == hc_mole_MODE) {
        DeviceDataget(thermo_ptr_d.hc_mole,    thermo_ptr_ref->hc_mole,    size*sizeof(REAL));
    }
    if (all || op == hs_mass_MODE) {
        DeviceDataget(thermo_ptr_d.hs_mass,    thermo_ptr_ref->hs_mass,    size*sizeof(REAL));
    }
    if (all || op == hs_mole_MODE) {
        DeviceDataget(thermo_ptr_d.hs_mole,    thermo_ptr_ref->hs_mole,    size*sizeof(REAL));
    }
    if (all || op == cp_mole_MODE) {
        DeviceDataget(thermo_ptr_d.cp_mole,    thermo_ptr_ref->cp_mole,    size*sizeof(REAL));
    }
    if (all || op == psi_MODE) {
        DeviceDataget(thermo_ptr_d.psi,        thermo_ptr_ref->psi,        size*sizeof(REAL));
    }

    if (all || op == c_MODE) {
        DeviceDataget(thermo_ptr_d.c,          thermo_ptr_ref->c,          size*sizeof(REAL)*(sp_num+2));
    }
    if (all || op == sp_ha_mole_MODE) {
        DeviceDataget(thermo_ptr_d.sp_ha_mole, thermo_ptr_ref->sp_ha_mole, size*sizeof(REAL)*sp_num);
    }
    if (all || op == sp_cp_mole_MODE) {
        DeviceDataget(thermo_ptr_d.sp_cp_mole, thermo_ptr_ref->sp_cp_mole, size*sizeof(REAL)*sp_num);
    }

}

/** Convert between the Y and X arrays held in thermo_ptr_d.
 *
 *  Y_TO_X forwards to formYupdateX_h and X_TO_Y forwards to formXupdateY_h.
 *  Both receive the current thermo_ptr_d.size, the species constant sp_W and
 *  the W_mix, X and Y device pointers. Any other mode only prints an error.
 *
 *  \param[in]  op            Data processing mode, see UPDATE_FRAC_MODE for more. (host)
 */
void thermoFluid_d::thermoFluid_update_fraction(UPDATE_FRAC_MODE op) {

    size_t n_points = thermo_ptr_d.size;

    if (op == Y_TO_X) {
        formYupdateX_h(n_points, species_d_ptr->species_const_d.sp_W, thermo_ptr_d.W_mix, thermo_ptr_d.X, thermo_ptr_d.Y);
    } else if (op == X_TO_Y) {
        formXupdateY_h(n_points, species_d_ptr->species_const_d.sp_W, thermo_ptr_d.W_mix, thermo_ptr_d.X, thermo_ptr_d.Y);
    } else {
        MPI_PRINTF("\033[31mWORRY!!! WHEN SETTING UPDATE FRACTION, UPDATE_FRAC_MODE SETTING ERROR.\033[0m\n");
    }
}

/** Forward T, the species T_range / sp_nasa constants and the sp_ha_mole /
 *  sp_cp_mole buffers to compute_ha_cp_h.
 *
 *  NOTE(review): presumably fills the per-species enthalpy and heat-capacity
 *  arrays -- confirm against compute_ha_cp_h.
 */
void thermoFluid_d::thermoFluid_compute_ha_cp() {

    auto &sp_const = species_d_ptr->species_const_d;

    compute_ha_cp_h(thermo_ptr_d.T, sp_const.T_range, sp_const.sp_nasa,
                    thermo_ptr_d.sp_ha_mole, thermo_ptr_d.sp_cp_mole);
}

/** Forward the species sp_nasa / sp_W constants and the hc_mass buffer to
 *  compute_hc_mass_h.
 */
void thermoFluid_d::thermoFluid_compute_hc_mass() {

    auto &sp_const = species_d_ptr->species_const_d;

    compute_hc_mass_h(sp_const.sp_nasa, sp_const.sp_W, thermo_ptr_d.hc_mass);
}

/** Call get_rho_h and then get_c_h on the buffers in thermo_ptr_d.
 *
 *  get_rho_h receives P, T, W_mix and rho; get_c_h then receives rho, Y, the
 *  species constant sp_W and c. The order matters if get_c_h reads the rho
 *  produced by the first call (presumably it does -- confirm).
 */
void thermoFluid_d::thermoFluid_get_rho_c() {

    auto &state = thermo_ptr_d;

    get_rho_h(state.P, state.T, state.W_mix, state.rho);

    get_c_h(state.rho, state.Y, species_d_ptr->species_const_d.sp_W, state.c);
}

/** Forward the rho and psi buffers of thermo_ptr_d to get_psi_h.
 */
void thermoFluid_d::thermoFluid_get_psi() {

    auto &state = thermo_ptr_d;

    get_psi_h(state.rho, state.psi);
}

/** Refresh the derived quantities, starting from the T, P and Y currently in thermo_ptr_d.
 *
 *  Sequence: Y -> X conversion, thermoFluid_get_rho_c, thermoFluid_compute_ha_cp,
 *  compute_ha_mass_h (sp_ha_mole, sp_W, Y, ha_mass) and finally thermoFluid_get_psi.
 */
void thermoFluid_d::thermoFluid_set_state_TPY() {

    // Mole fractions first; the later steps are called in the original order.
    thermoFluid_update_fraction(Y_TO_X);
    thermoFluid_get_rho_c();
    thermoFluid_compute_ha_cp();

    auto &state = thermo_ptr_d;
    compute_ha_mass_h(state.sp_ha_mole, species_d_ptr->species_const_d.sp_W,
                      state.Y, state.ha_mass);

    thermoFluid_get_psi();
}

void thermoFluid_d::thermoFluid_update_c_from_TP() {

    update_c_from_T_P_h(thermoFluid_d_ptr->thermo_ptr_d.c, thermoFluid_d_ptr->thermo_ptr_d.T,
            thermoFluid_d_ptr->thermo_ptr_d.P);
            
}

void thermoFluid_d::thermoFluid_update_TP_from_c() {

    update_T_P_from_c_h(thermoFluid_d_ptr->thermo_ptr_d.c, thermoFluid_d_ptr->thermo_ptr_d.T,
            thermoFluid_d_ptr->thermo_ptr_d.P);
            
}

/** Forward the X and c buffers of thermo_ptr_d to h_constraint_h.
 */
void thermoFluid_d::thermoFluid_h_constraint() {

    auto &state = thermo_ptr_d;

    h_constraint_h(state.X, state.c);
}

/** Forward the c buffer of thermo_ptr_d to c_constraint_h.
 */
void thermoFluid_d::thermoFluid_c_constraint() {

    auto &state = thermo_ptr_d;

    c_constraint_h(state.c);
}

/** Refresh the derived quantities when ha_mass, P and Y are the inputs.
 *
 *  Sequence: Y -> X conversion, enthalpy_to_temperature_h (c, T_range, sp_nasa,
 *  Y, sp_W, ha_mass, T), thermoFluid_get_rho_c, thermoFluid_compute_ha_cp,
 *  compute_ha_mass_h (sp_ha_mole, sp_W, Y, ha_mass) and finally thermoFluid_get_psi.
 *
 *  NOTE(review): enthalpy_to_temperature_h presumably writes T from ha_mass --
 *  confirm against its definition.
 */
void thermoFluid_d::thermoFluid_set_state_hPY() {

    auto &state    = thermo_ptr_d;
    auto &sp_const = species_d_ptr->species_const_d;

    thermoFluid_update_fraction(Y_TO_X);

    enthalpy_to_temperature_h(state.c, sp_const.T_range, sp_const.sp_nasa,
                              state.Y, sp_const.sp_W, state.ha_mass, state.T);

    // From here on the sequence matches thermoFluid_set_state_TPY.
    thermoFluid_get_rho_c();
    thermoFluid_compute_ha_cp();

    compute_ha_mass_h(state.sp_ha_mole, sp_const.sp_W, state.Y, state.ha_mass);

    thermoFluid_get_psi();
}