/*******************************************************************************************
 * Parameters and function implementations of the device-side SEULEX solver (class seulex_d,
 * linearly implicit Euler extrapolation for stiff chemistry ODEs).
 ******************************************************************************************/
#include <math.h>

#include "common.h"
#include "thermoFluid.h"
#include "reactions.h"
#include "species.h"

#include "seulex.h"
#include "seulex_kernel.h"

/** Process-wide handle to the seulex_d solver instance.
 *  Null until the owning code assigns it.
 */
seulex_d *seulex_d_ptr = nullptr;

/** Construct an empty solver.
 *
 *  All buffer pointers handled by seulex_init() and ~seulex_d() start out
 *  null. Releasing them is then well defined (hipFree/free on a null
 *  pointer is a no-op) even if seulex_init() was never called.
 */
seulex_d::seulex_d() {
    // Device: per-entry step-size / bookkeeping arrays.
    dt_c_d        = nullptr;
    dt_c_new_d    = nullptr;
    dt_sum_d      = nullptr;
    denom_d       = nullptr;
    flag_reject_d = nullptr;
    unsuccess_d   = nullptr;
    success_d     = nullptr;

    // Device: scalar reduction targets.
    flag_d        = nullptr;
    flag_end_d    = nullptr;
    dt_min_d      = nullptr;
    dt_sum_min_d  = nullptr;

    // Device: state vectors and workspaces.
    scale_d       = nullptr;
    y_temp_d      = nullptr;
    y_d           = nullptr;
    y_seq_d       = nullptr;
    dydx_d        = nullptr;
    c_old_d       = nullptr;
    tmp_ptr_d     = nullptr;
    table_d       = nullptr;

    // Host buffers.
    flag_end      = nullptr;
    dt_min        = nullptr;
    dt_sum_min    = nullptr;
    A_h           = nullptr;
    tmp_ptr_h     = nullptr;
}

/** Release the host and device buffers owned by the solver.
 *
 *  Frees every buffer that seulex_init() allocates, plus dt_sum_d. The
 *  original destructor freed only a subset and leaked the rest.
 *  hipFree/free ignore null pointers.
 */
seulex_d::~seulex_d() {

    // Per-entry step-size / bookkeeping arrays (length size).
    CUDACHECK(hipFree(dt_c_d));
    CUDACHECK(hipFree(dt_c_new_d));
    CUDACHECK(hipFree(dt_sum_d));
    CUDACHECK(hipFree(denom_d));
    CUDACHECK(hipFree(flag_reject_d));
    CUDACHECK(hipFree(unsuccess_d));
    CUDACHECK(hipFree(success_d));

    // Single-value reduction targets.
    CUDACHECK(hipFree(flag_d));
    CUDACHECK(hipFree(flag_end_d));
    CUDACHECK(hipFree(dt_min_d));
    CUDACHECK(hipFree(dt_sum_min_d));

    // State vectors of (sp_num+2) values per entry.
    CUDACHECK(hipFree(scale_d));
    CUDACHECK(hipFree(y_temp_d));
    CUDACHECK(hipFree(y_d));
    CUDACHECK(hipFree(y_seq_d));
    CUDACHECK(hipFree(dydx_d));
    CUDACHECK(hipFree(c_old_d));

    // PLU workspace and extrapolation table.
    CUDACHECK(hipFree(tmp_ptr_d));
    CUDACHECK(hipFree(table_d));

    // Host-side buffers allocated with malloc() in seulex_init().
    free(flag_end);
    free(dt_min);
    free(dt_sum_min);
    free(A_h);
    free(tmp_ptr_h);

}

void seulex_d::seulex_init() {
    size_t size = thermoFluid_d_ptr->thermo_ptr_d.size;
    size_t sp_num    = species_d_ptr->species_const_d.sp_num;
    size_t n_vars = sp_num + 2;
    size_t pitch;

    REAL ratio;

    for (int k = 0; k < k_max+1; k++) {
        for (int l = 0; l < k_max+1; l++) {
            coeff_h[l + k*(k_max+1)] = 0.;
        }
    }

    for (int k = 0; k < k_max+1; k++) {
        for (int l = 0; l < k; l++) {
            ratio = (REAL(n_seq_h[k])/n_seq_h[l] - 1.);
            coeff_h[l + k*(k_max+1)] = 1./ratio;
        }
    }


    CUDACHECK(hipMemcpyToSymbol(HIP_SYMBOL(HIP_SYMBOL(n_seq_d)), &n_seq_h, sizeof(n_seq_h), 0, hipMemcpyHostToDevice));
    CUDACHECK(hipMemcpyToSymbol(HIP_SYMBOL(HIP_SYMBOL(coeff_d)), &coeff_h, sizeof(coeff_h), 0, hipMemcpyHostToDevice));


    flag_end = (REAL *)malloc(sizeof(REAL));
    CUDACHECK(hipMalloc((void**)&flag_end_d, sizeof(REAL)));
    dt_min = (REAL *)malloc(sizeof(REAL));
    CUDACHECK(hipMalloc((void**)&dt_min_d, sizeof(REAL)));
    dt_sum_min = (REAL *)malloc(sizeof(REAL));
    CUDACHECK(hipMalloc((void**)&dt_sum_min_d, sizeof(REAL)));

    CUDACHECK(hipMalloc((void**)&flag_d, sizeof(REAL)));

    CUDACHECK(hipMalloc((void**)&dt_c_d, sizeof(REAL)*size));
    CUDACHECK(hipMalloc((void**)&dt_c_new_d, sizeof(REAL)*size));
    CUDACHECK(hipMalloc((void**)&denom_d,  sizeof(REAL)*size));
    CUDACHECK(hipMalloc((void**)&flag_reject_d,  sizeof(REAL)*size));
    CUDACHECK(hipMalloc((void**)&unsuccess_d,  sizeof(REAL)*size));
    CUDACHECK(hipMalloc((void**)&success_d,  sizeof(REAL)*size));

    cuda_mem_value_init(dt_c, dt_c_d, 1, 1, size, block_set_J0);
    cuda_mem_value_init(dt_c, dt_c_new_d, 1, 1, size, block_set_J0);
    cuda_mem_value_init(0.0, denom_d,  1, 1, size, block_set_J0);

    CUDACHECK(hipMalloc((void**)&scale_d, (sp_num+2)*sizeof(REAL)*size));

    cuda_mem_value_init(0.0, scale_d, (sp_num+2), (sp_num+2), size, block_set_J0);
    A_h = (REAL*)malloc((sp_num+2)*size*sizeof(REAL));
    
    tmp_ptr_h = (REAL*)malloc((2*(sp_num+2)+1)*(sp_num+2)*size*sizeof(REAL));

    CUDACHECK(hipMalloc((void**)&tmp_ptr_d, (2*(sp_num+2)+1)*(sp_num+2)*sizeof(REAL)*size));
    
    pitch_seulex_tmp_ptr = pitch;

    cuda_mem_value_init(0.0, tmp_ptr_d, (2*(sp_num+2)+1)*(sp_num+2), (2*(sp_num+2)+1)*(sp_num+2), size, block_set_J0);

    CUDACHECK(hipMalloc((void**)&y_temp_d, (sp_num+2)*sizeof(REAL)*size));
    CUDACHECK(hipMalloc((void**)&y_d, (sp_num+2)*sizeof(REAL)*size));
    CUDACHECK(hipMalloc((void**)&y_seq_d, (sp_num+2)*sizeof(REAL)*size));
    CUDACHECK(hipMalloc((void**)&dydx_d, (sp_num+2)*sizeof(REAL)*size));
    CUDACHECK(hipMalloc((void**)&c_old_d, (sp_num+2)*sizeof(REAL)*size));

    pitch_n_vars = pitch;

    cuda_mem_value_init(0.0, y_temp_d, (sp_num+2), (sp_num+2), size, block_set_J0);
    cuda_mem_value_init(0.0, y_d,      (sp_num+2), (sp_num+2), size, block_set_J0);
    cuda_mem_value_init(0.0, y_seq_d,  (sp_num+2), (sp_num+2), size, block_set_J0);
    cuda_mem_value_init(0.0, dydx_d,   (sp_num+2), (sp_num+2), size, block_set_J0);
    cuda_mem_value_init(0.0, c_old_d,  (sp_num+2), (sp_num+2), size, block_set_J0);

    CUDACHECK(hipMalloc((void**)&table_d, (sp_num+2)*k_max*sizeof(REAL)*size));
    
    pitch_table = pitch;

    cuda_mem_value_init(0.0, table_d, (sp_num+2)*k_max, (sp_num+2)*k_max, size, block_set_J0);

    CUDACHECK(hipMemcpyToSymbol(HIP_SYMBOL(HIP_SYMBOL(pitch_n_vars_d)), &n_vars, sizeof(size_t), 0, hipMemcpyHostToDevice));

    pitch = (2*(sp_num+2)+1)*(sp_num+2);

    CUDACHECK(hipMemcpyToSymbol(HIP_SYMBOL(HIP_SYMBOL(pitch_seulex_tmp_ptr_d)), &pitch, sizeof(size_t), 0, hipMemcpyHostToDevice));
    
    pitch = (sp_num+2)*k_max;
    
    CUDACHECK(hipMemcpyToSymbol(HIP_SYMBOL(HIP_SYMBOL(pitch_table_d)), &pitch, sizeof(size_t), 0, hipMemcpyHostToDevice));
}

/** Transfer a Jacobian buffer with DeviceDataset().
 *
 *  Hands reactions_d_ptr->J_h and J_ref to DeviceDataset, along with a byte
 *  count of size * (sp_num+2) * (sp_num+2) REAL values.
 *  NOTE(review): the copy direction is fixed by DeviceDataset's argument
 *  order, which is defined elsewhere. Confirm which side is host and which is device.
 */
void seulex_d::seulex_d_set_J(REAL *J_ref) {
    const size_t n_vars = species_d_ptr->species_const_d.sp_num + 2;
    const size_t n_ent  = thermoFluid_d_ptr->thermo_ptr_d.size;
    const size_t bytes  = n_ent * n_vars * n_vars * sizeof(REAL);

    DeviceDataset(reactions_d_ptr->J_h, J_ref, bytes);
}

/** Transfer a state-increment buffer with DeviceDataset().
 *
 *  Hands y_temp_h and dy_ref to DeviceDataset, along with a byte count of
 *  size * (sp_num+2) REAL values.
 *  NOTE(review): the copy direction is fixed by DeviceDataset's argument
 *  order, which is defined elsewhere. Confirm which side is host and which is device.
 */
void seulex_d::seulex_d_set_dy(REAL *dy_ref) {
    const size_t n_vars = species_d_ptr->species_const_d.sp_num + 2;
    const size_t n_ent  = thermoFluid_d_ptr->thermo_ptr_d.size;

    DeviceDataset(y_temp_h, dy_ref, n_ent * n_vars * sizeof(REAL));
}

/** Fetch the Jacobian J_d into the host buffer reactions_d_ptr->J_h.
 *
 *  Copies size * (sp_num+2)^2 REAL values with DeviceDataget. When DEBUG is
 *  defined, prints the first (sp_num+2) x (sp_num+2) block row by row.
 */
void seulex_d::seulex_d_get_J() {
    const size_t n_ent  = thermoFluid_d_ptr->thermo_ptr_d.size;
    const size_t n_vars = species_d_ptr->species_const_d.sp_num + 2;

    REAL *J_h = reactions_d_ptr->J_h;
    const size_t bytes = n_ent * n_vars * n_vars * sizeof(REAL);

    DeviceDataget(reactions_d_ptr->J_d, J_h, bytes);

#ifdef DEBUG
    MPI_PRINTF("\n");
    for (int row = 0; row < n_vars; ++row) {
        for (int col = 0; col < n_vars; ++col) {
            MPI_PRINTF("J[%d][%d]: %e\t", row, col, J_h[col + n_vars*row]);
        }
        MPI_PRINTF("\n");
    }
#endif //DEBUG
}

/** Copy reactions_ptr_d.dcdt (device) into the host buffer A_h.
 *
 *  Copies size * (sp_num+2) REAL values with DeviceDataget. Under DEBUG, it
 *  prints the sp_num+2 values of the first entry.
 *
 *  A_h holds only (sp_num+2) values per entry (see seulex_init), so the
 *  DEBUG dump prints a vector. The former (sp_num+2) x (sp_num+2) dump read
 *  past the end of A_h whenever size < sp_num+2.
 */
void seulex_d::seulex_d_get_A() {
    size_t size      = thermoFluid_d_ptr->thermo_ptr_d.size;
    size_t sp_num    = species_d_ptr->species_const_d.sp_num;

    DeviceDataget(reactions_d_ptr->reactions_ptr_d.dcdt, A_h, size*(sp_num+2)*sizeof(REAL));

#ifdef DEBUG
    MPI_PRINTF("\n");
    for (int i = 0; i < (sp_num+2); i++) {
        MPI_PRINTF("A[%d]: %e\t", i, *(A_h+i));
        MPI_PRINTF("\n");
    }
#endif //DEBUG
}

/** Fetch y_temp_d into the host buffer y_temp_h.
 *
 *  Copies size * (sp_num+2) REAL values with DeviceDataget. Under DEBUG,
 *  prints the first sp_num+2 values, one per line.
 */
void seulex_d::seulex_d_get_dy() {
    const size_t n_vars = species_d_ptr->species_const_d.sp_num + 2;
    const size_t count  = thermoFluid_d_ptr->thermo_ptr_d.size * n_vars;

    DeviceDataget(y_temp_d, y_temp_h, count*sizeof(REAL));

#ifdef DEBUG
    MPI_PRINTF("\n");
    for (int idx = 0; idx < n_vars; ++idx) {
        MPI_PRINTF("dy[%d]: %e\t", idx, y_temp_h[idx]);
        MPI_PRINTF("\n");
    }
#endif //DEBUG
}


/** Refresh scale_d from the current state thermo_ptr_d.c (via compute_scale0_h). */
void seulex_d::compute_scale() {
    compute_scale0_h(thermoFluid_d_ptr->thermo_ptr_d.c, scale_d);
}

/** Build the iteration matrix from the Jacobian J_d and step sizes dt_c_d.
 *
 *  Calls compute_A_h with tmp_ptr_d as the workspace argument and
 *  extrapolation column 0.
 */
void seulex_d::compute_A() {
    const int column = 0;

    compute_A_h(tmp_ptr_d, reactions_d_ptr->J_d, dt_c_d, column);
}

/** PLU-factorise the iteration matrix held in tmp_ptr_d.
 *
 *  First resets the permutation part of tmp_ptr_d (initP_h). It then runs
 *  PLU_decomposition_h, which also receives reactions_ptr_d.dcdt.
 */
void seulex_d::PLU_decomposition() {
    initP_h(tmp_ptr_d);
    PLU_decomposition_h(reactions_d_ptr->reactions_ptr_d.dcdt, tmp_ptr_d);
}

/** Fetch the packed PLU workspace tmp_ptr_d into the host buffer tmp_ptr_h.
 *
 *  Each entry occupies (2*(sp_num+2)+1)*(sp_num+2) values, laid out as
 *  P (n_vars), then L (n_vars x n_vars), then U (n_vars x n_vars). Under
 *  DEBUG, all three parts are printed for every entry.
 */
void seulex_d::seulex_d_get_PLU() {
    const size_t n_ent  = thermoFluid_d_ptr->thermo_ptr_d.size;
    const size_t n_vars = species_d_ptr->species_const_d.sp_num + 2;
    const size_t stride = (2*n_vars + 1)*n_vars;   // values per entry

    DeviceDataget(tmp_ptr_d, tmp_ptr_h, n_ent*stride*sizeof(REAL));

#ifdef DEBUG
    for (size_t s = 0; s < n_ent; ++s) {
        const REAL *P = tmp_ptr_h + stride*s;
        const REAL *L = P + n_vars;
        const REAL *U = L + n_vars*n_vars;

        MPI_PRINTF("\n");
        MPI_PRINTF("P is:\n");
        for (size_t r = 0; r < n_vars; ++r) {
            MPI_PRINTF("%e  \t", P[r]);
            MPI_PRINTF("\n");
        }

        MPI_PRINTF("\n");
        MPI_PRINTF("L is:\n");
        for (size_t r = 0; r < n_vars; ++r) {
            for (size_t c = 0; c < n_vars; ++c) {
                MPI_PRINTF("%e  \t", L[c + r*n_vars]);
            }
            MPI_PRINTF("\n");
        }

        MPI_PRINTF("\n");
        MPI_PRINTF("U is:\n");
        for (size_t r = 0; r < n_vars; ++r) {
            for (size_t c = 0; c < n_vars; ++c) {
                MPI_PRINTF("%e  \t", U[c + r*n_vars]);
            }
            MPI_PRINTF("\n");
        }
    }
#endif //DEBUG
}

/** Copy y_seq_d into the caller-supplied host buffer y_seq_h.
 *
 *  y_seq_h must have room for size * (sp_num+2) REAL values.
 */
void seulex_d::get_y_seq(REAL *y_seq_h) {
    const size_t n_vars = species_d_ptr->species_const_d.sp_num + 2;
    const size_t count  = thermoFluid_d_ptr->thermo_ptr_d.size * n_vars;

    DeviceDataget(y_seq_d, y_seq_h, count*sizeof(REAL));
}

void seulex_d::solve_linear_system() {

    solve_linear_system_h(tmp_ptr_d, reactions_d_ptr->reactions_ptr_d.dcdt);

    //hipDeviceSynchronize();
}

/** One SEUL sweep of n_seq_h[k] substeps for every entry; result in y_seq_d.
 *
 *  Sequence (device work is issued through launcher helpers defined elsewhere):
 *    1. compute_A_h builds the iteration matrix in tmp_ptr_d from J_d and
 *       dt_c_d for column k. cal_react_rate evaluates the rates at
 *       thermo_ptr_d.c into dcdt. PLU_decomposition() factorises tmp_ptr_d,
 *       and solve_linear_system() solves with dcdt as right-hand side.
 *    2. y_temp_d is initialised from thermo_ptr_d.c. Each of the remaining
 *       n_seq_h[k]-1 substeps updates y_temp_d, re-evaluates the rates into
 *       dydx_d (saved in y_seq_d) and solves again.
 *    3. For k <= 1, the first substep (nn == 1) additionally runs a check
 *       (compute_dy1_h / update_dcdt_h / compute_flag_h). Its per-entry
 *       results in success_d are max-reduced into flag_d.
 *
 *  @param k  extrapolation column. Selects n_seq_h[k] and is forwarded to
 *            compute_A_h / update_dcdt_h.
 *  @return   1. if the k <= 1 check flagged any entry (y_seq_d is then a copy
 *            of thermo_ptr_d.c); 0. after a completed sweep.
 *            NOTE: despite the local name `success`, seulex_solver() treats
 *            a non-zero return as a rejected step: it calls
 *            update_dx_modify_h(0.5, ...) and restarts its outer loop.
 *  Stage timings are printed only when PERF is defined.
 */
REAL seulex_d::seul(int k) {

    size_t size   = thermoFluid_d_ptr->thermo_ptr_d.size;
    size_t sp_num = species_d_ptr->species_const_d.sp_num;
    REAL success = 0.;

    // Clear the per-entry check results written by compute_flag_h below.
    cuda_mem_value_init(0.0, success_d,  1, 1, size, block_set_J0);

    // Iteration matrix for column k (workspace tmp_ptr_d).
    compute_A_h(tmp_ptr_d, reactions_d_ptr->J_d, dt_c_d, k);

#ifdef PERF
	my_timer_opencc tm;
    hipDeviceSynchronize();
	tm.start();
#endif //PERF

    // Rates at the current state c -> dcdt (first right-hand side).
    reactions_d_ptr->cal_react_rate(thermoFluid_d_ptr->thermo_ptr_d.c, reactions_d_ptr->reactions_ptr_d.dcdt);

#ifdef PERF
    hipDeviceSynchronize();
	tm.stop();
    MPI_PRINTF("\ncal_react_rate Timecost of CPU (ODE) = %lf s (SEUL)\n", tm.time_use);
	tm.start();
#endif //PERF

    PLU_decomposition();

#ifdef PERF
    hipDeviceSynchronize();
	tm.stop();
    MPI_PRINTF("\nPLU_decomposition Timecost of CPU (ODE) = %lf s (SEUL)\n", tm.time_use);
	tm.start();
#endif //PERF

    solve_linear_system();

#ifdef PERF
    hipDeviceSynchronize();
	tm.stop();
    MPI_PRINTF("\nsolve_linear_system Timecost of CPU (ODE) = %lf s (SEUL)\n", tm.time_use);
	tm.start();
#endif //PERF

    // Substep state starts from the current state c.
    init_y_temp_h(y_temp_d, thermoFluid_d_ptr->thermo_ptr_d.c);
    
    for (int nn = 1; nn < n_seq_h[k]; nn++) {
        // Advance y_temp_d with the latest correction held in dcdt.
        update_y_temp_h(y_temp_d, reactions_d_ptr->reactions_ptr_d.dcdt);

        reactions_d_ptr->cal_react_rate(y_temp_d, dydx_d);

        // Save the new rates; dcdt is reused as scratch by the check below.
        cuda_copy(dydx_d, y_seq_d, (sp_num+2), (sp_num+2), size, block_set_J0);

        // Check on the first substep of the two lowest columns.
        if (nn == 1 && k <= 1) {
            compute_dy1_h(reactions_d_ptr->reactions_ptr_d.dcdt, scale_d, denom_d);

            update_dcdt_h(dydx_d, dt_c_d, k, reactions_d_ptr->reactions_ptr_d.dcdt);

            solve_linear_system_h(tmp_ptr_d, reactions_d_ptr->reactions_ptr_d.dcdt);

            compute_flag_h(denom_d, reactions_d_ptr->reactions_ptr_d.dcdt, scale_d, success_d); // per-entry check result -> success_d

            // Any flagged entry makes the whole sweep report failure (blocking D2H copy).
            cuda_max(success_d, flag_d, 1, 1, size, block_set_J0);
            CUDACHECK(hipMemcpy(&success, flag_d, sizeof(REAL), hipMemcpyDeviceToHost));

            if (success != 0.) {
                cuda_copy(thermoFluid_d_ptr->thermo_ptr_d.c, y_seq_d, (sp_num+2), (sp_num+2), size, block_set_J0);

                return 1.;
            }
        }

        // Restore the saved rates as right-hand side and solve for the next correction.
        cuda_copy(y_seq_d, reactions_d_ptr->reactions_ptr_d.dcdt, (sp_num+2), (sp_num+2), size, block_set_J0);

        solve_linear_system_h(tmp_ptr_d, reactions_d_ptr->reactions_ptr_d.dcdt);
    }

    // Presumably combines y_temp_d and the final correction into y_seq_d; confirm in update_c_h.
    update_c_h(y_seq_d, y_temp_d, reactions_d_ptr->reactions_ptr_d.dcdt);

#ifdef PERF
    hipDeviceSynchronize();
	tm.stop();
    MPI_PRINTF("\nSelu Timecost of CPU (ODE) = %lf s (SEUL)\n", tm.time_use);
#endif //PERF

    return 0.;
}

/** Advance all entries by one SEULEX macro step towards t_end.
 *
 *  Outer loop: evaluate the Jacobian at y_d, then run extrapolation columns
 *  k = 0 .. k_targ+1.
 *   - k == 0: y_seq_d (from seul(0)) is copied into y_d.
 *   - k >= 1: the table is extended, y_d is extrapolated, and the loop stops
 *     early once err_compute_h flags no entry (max over flag_reject_d == 0).
 *  If seul(k) returns non-zero, update_dx_modify_h is called with factor
 *  0.5 (presumably shrinking dt_c_d for the affected entries) and the outer
 *  loop restarts. After more than 100 outer iterations the state is dumped
 *  via get_thermoFluid and the process exits with status 1.
 *
 *  @param t_end  target time (also stored in t_end_h).
 *  @param flag   set to 1. if min(dt_sum_d) < t_end after the step, else 0.
 *
 *  Timing output (with its device synchronisations) is emitted only when
 *  PERF is defined, consistent with seul().
 */
void seulex_d::seulex_solver(REAL t_end, REAL *flag) {

    t_end_h = t_end;

    const size_t size   = thermoFluid_d_ptr->thermo_ptr_d.size;
    const size_t sp_num = species_d_ptr->species_const_d.sp_num;
    const size_t n_vars = sp_num + 2;

    // Target extrapolation order from the tolerances, clamped by max(2, min(k_max-1, .)).
    REAL log_tol = -log10(rtol + atol)*0.6 + 0.5;
    int k_targ = max(2, min(k_max - 1, int(log_tol)));

    compute_scale();

    int flag_loop_num = 0;
    int k      = 0;
    int reject = 1;      // set to 0 when seul() rejects the step -> repeat outer loop

    REAL result  = 0.;
    REAL success = 0.;

    cuda_mem_value_init(0.0, flag_reject_d, 1, 1, size, block_set_J0);
    cuda_mem_value_init(0.0, unsuccess_d,   1, 1, size, block_set_J0);

    cuda_copy(thermoFluid_d_ptr->thermo_ptr_d.c, y_d, n_vars, n_vars, size, block_set_J0);

#ifdef PERF
    my_timer_opencc tm;
#endif //PERF

    do {
        k = 0;
        reject = 1;

        cuda_mem_value_init(0.0, flag_reject_d, 1, 1, size, block_set_J0);

#ifdef PERF
        hipDeviceSynchronize(); tm.start();
#endif //PERF
        reactions_d_ptr->jacobian(y_d);
#ifdef PERF
        hipDeviceSynchronize(); tm.stop();
        MPI_PRINTF("\nJacobian calculation time cost = %lf s \n", tm.time_use);
#endif //PERF

        result  = 0.;
        success = 0.;

        do {
#ifdef PERF
            hipDeviceSynchronize(); tm.start();
#endif //PERF
            success = seul(k);
#ifdef PERF
            hipDeviceSynchronize(); tm.stop();
            MPI_PRINTF("\nseul(k) calculation time cost = %lf s \n", tm.time_use);
#endif //PERF

            // Non-zero from seul() means the step was rejected: adjust dt and restart.
            if (success != 0.) {
                update_dx_modify_h(0.5, dt_sum_d, t_end, dt_c_d, dt_c_d, success_d, unsuccess_d);
                reject = 0;
                break;
            }

            if (k == 0) {
                cuda_copy(y_seq_d, y_d, n_vars, n_vars, size, block_set_J0);
            } else {
                compute_table_h(k - 1, y_seq_d, table_d);

                extrapolate_h(k, table_d, y_d);

                err_compute_h(y_d, table_d, scale_d, flag_reject_d);

                // Blocking D2H copy also orders the host after the kernels above.
                cuda_max(flag_reject_d, flag_d, 1, 1, size, block_set_J0);
                CUDACHECK(hipMemcpy(&result, flag_d, sizeof(REAL), hipMemcpyDeviceToHost));

                if (result == 0.) {
                    break;
                }
            }

            k += 1;

        } while (k <= k_targ + 1);

        flag_loop_num += 1;

        if (flag_loop_num > 100) {
            MPI_PRINTF("\nWrong !!! Unable to jump out of loop! Please check the settings!\n");
            thermo_ptr thermo_ptr_h;

            cuda_copy(y_d, thermoFluid_d_ptr->thermo_ptr_d.c, n_vars, n_vars, size, block_set_J0);

            get_thermoFluid(&thermo_ptr_h, THERMO_ALL);

            // Non-zero status so the shell / job scheduler sees the failure.
            exit(1);
        }

    } while (reject == 0);

    cuda_copy(y_d, thermoFluid_d_ptr->thermo_ptr_d.c, n_vars, n_vars, size, block_set_J0);

    update_dt_sum_h(dt_c_d, dt_c_new_d, t_end, dt_sum_d);

    cuda_min(dt_sum_d, dt_sum_min_d, 1, 1, size, block_set_J0);

    CUDACHECK(hipMemcpy(dt_sum_min, dt_sum_min_d, sizeof(REAL), hipMemcpyDeviceToHost));

    update_dx_h(2, dt_sum_d, t_end, dt_c_d, dt_c_d, unsuccess_d);

    *flag = 0.;
    if (*dt_sum_min < t_end) *flag = 1.;
}

