/*---------------------------------------------------------------------------*\
  =========                 |
  \\      /  F ield         | OpenFOAM: The Open Source CFD Toolbox
   \\    /   O peration     |
    \\  /    A nd           | www.openfoam.com
     \\/     M anipulation  |
-------------------------------------------------------------------------------
    Copyright (C) 2011-2016 OpenFOAM Foundation
    Copyright (C) 2015-2017 OpenCFD Ltd.
-------------------------------------------------------------------------------
License
    This file is part of OpenFOAM.

    OpenFOAM is free software: you can redistribute it and/or modify it
    under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    OpenFOAM is distributed in the hope that it will be useful, but WITHOUT
    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
    for more details.

    You should have received a copy of the GNU General Public License
    along with OpenFOAM.  If not, see <http://www.gnu.org/licenses/>.

\*---------------------------------------------------------------------------*/

#include "hePsiThermo.H"

namespace Foam
{

//- Element-wise functor for cells and non-fixed-value patches.
//  Inputs: a mixture and the tuple (he, p, T).
//  Output: the tuple (T, psi, mu, alphah), all evaluated at the (possibly
//  updated) temperature.
template<class MixtureType>
struct hePsiThermoCalculateFunctor
{
    //- When true, T is first recomputed from (he, p, T) via mixture.THE
    const bool updateT;

    hePsiThermoCalculateFunctor(const bool doUpdateT)
    :
        updateT(doUpdateT)
    {}

    // const-qualified so the functor can also be invoked through a const
    // reference by the Thrust/CUB back end
    __host__ __device__
    thrust::tuple<scalar, scalar, scalar, scalar>
    operator()
    (
        const MixtureType& mixture,
        const thrust::tuple<scalar, scalar, scalar>& t
    ) const
    {
        const scalar h = thrust::get<0>(t);
        const scalar p = thrust::get<1>(t);
        scalar T = thrust::get<2>(t);

        if (updateT)
        {
            // The incoming T is passed as the starting value for THE
            T = mixture.THE(h, p, T);
        }

        return thrust::make_tuple
        (
            T,
            mixture.psi(p, T),
            mixture.mu(p, T),
            mixture.alphah(p, T)
        );
    }
};


//- Element-wise functor for patches whose T fixes its value.
//  Inputs: a mixture and the tuple (p, T).
//  Output: the tuple (he, psi, mu, alphah) evaluated at (p, T).
template<class MixtureType>
struct hePsiThermoHECalculateFunctor
{
    __host__ __device__
    thrust::tuple<scalar, scalar, scalar, scalar>
    operator()
    (
        const MixtureType& mixture,
        const thrust::tuple<scalar, scalar>& t
    ) const
    {
        const scalar p = thrust::get<0>(t);
        const scalar T = thrust::get<1>(t);

        return thrust::make_tuple
        (
            mixture.HE(p, T),
            mixture.psi(p, T),
            mixture.mu(p, T),
            mixture.alphah(p, T)
        );
    }
};

}

// * * * * * * * * * * * * * Private Member Functions  * * * * * * * * * * * //

template<class BasicPsiThermo, class MixtureType>
void Foam::hePsiThermo<BasicPsiThermo, MixtureType>::calculate
(
    const volScalargpuField& p,
    volScalargpuField& T,
    volScalargpuField& he,
    volScalargpuField& psi,
    volScalargpuField& mu,
    volScalargpuField& alpha,
    const bool doOldTimes
)
{
    // Note: update oldTimes before current time so that if T.oldTime() is
    // created from T, it starts from the unconverted T
    if (doOldTimes && (p.nOldTimes() || T.nOldTimes()))
    {
        calculate
        (
            p.oldTime(),
            T.oldTime(),
            he.oldTime(),
            psi.oldTime(),
            mu.oldTime(),
            alpha.oldTime(),
            true
        );
    }

    typedef typename MixtureType::thermoType mixtureThermoType;

    // --- Internal (cell) values
    // For every cell i, with mixture m of that cell:
    //     if (updateT()) T[i] = m.THE(he[i], p[i], T[i]);
    //     psi[i]   = m.psi(p[i], T[i]);
    //     mu[i]    = m.mu(p[i], T[i]);
    //     alpha[i] = m.alphah(p[i], T[i]);
    // evaluated element-wise by hePsiThermoCalculateFunctor.

    const scalargpuField& hCells = he.primitiveField();
    const scalargpuField& pCells = p.primitiveField();

    scalargpuField& TCells = T.primitiveFieldRef();
    scalargpuField& psiCells = psi.primitiveFieldRef();
    scalargpuField& muCells = mu.primitiveFieldRef();
    scalargpuField& alphaCells = alpha.primitiveFieldRef();

    // updateCellMixture() is called before the mixture array is read
    this->updateCellMixture();
    const thrust::device_ptr<const mixtureThermoType> mixtureCellsPtr =
        this->mixtureCells();

    // TCells is read (third input) and written (first output) at the same
    // index only, i.e. an element-wise in-place update
    thrust::transform
    (
        mixtureCellsPtr,
        mixtureCellsPtr + hCells.size(),
        thrust::make_zip_iterator
        (
            thrust::make_tuple(hCells.begin(), pCells.begin(), TCells.begin())
        ),
        thrust::make_zip_iterator
        (
            thrust::make_tuple
            (
                TCells.begin(),
                psiCells.begin(),
                muCells.begin(),
                alphaCells.begin()
            )
        ),
        hePsiThermoCalculateFunctor<mixtureThermoType>(this->updateT())
    );

    // --- Boundary values
    const volScalargpuField::Boundary& pBf = p.boundaryField();
    volScalargpuField::Boundary& TBf = T.boundaryFieldRef();
    volScalargpuField::Boundary& psiBf = psi.boundaryFieldRef();
    volScalargpuField::Boundary& heBf = he.boundaryFieldRef();
    volScalargpuField::Boundary& muBf = mu.boundaryFieldRef();
    volScalargpuField::Boundary& alphaBf = alpha.boundaryFieldRef();

    forAll(pBf, patchi)
    {
        const fvPatchScalargpuField& pp = pBf[patchi];
        fvPatchScalargpuField& pT = TBf[patchi];
        fvPatchScalargpuField& ppsi = psiBf[patchi];
        fvPatchScalargpuField& phe = heBf[patchi];
        fvPatchScalargpuField& pmu = muBf[patchi];
        fvPatchScalargpuField& palpha = alphaBf[patchi];

        // Both branches only read the face mixtures. Holding them through a
        // pointer-to-const accepts either a const or non-const device_ptr
        // from updatePatchFaceMixture (previously the two branches disagreed).
        const thrust::device_ptr<const mixtureThermoType> mixtureFacesPtr =
            this->updatePatchFaceMixture(patchi);

        if (pT.fixesValue())
        {
            // T fixes its value on this patch. Per face f:
            //     he[f] = m.HE(p[f], T[f]);
            //     psi/mu/alpha from (p[f], T[f])
            thrust::transform
            (
                mixtureFacesPtr,
                mixtureFacesPtr + pp.size(),
                thrust::make_zip_iterator
                (
                    thrust::make_tuple(pp.begin(), pT.begin())
                ),
                thrust::make_zip_iterator
                (
                    thrust::make_tuple
                    (
                        phe.begin(),
                        ppsi.begin(),
                        pmu.begin(),
                        palpha.begin()
                    )
                ),
                hePsiThermoHECalculateFunctor<mixtureThermoType>()
            );
        }
        else
        {
            // Same per-element update as the cells:
            //     if (updateT()) T[f] = m.THE(he[f], p[f], T[f]);
            //     psi/mu/alpha from (p[f], T[f])
            // pT is updated in place (same index read and written).
            thrust::transform
            (
                mixtureFacesPtr,
                mixtureFacesPtr + pp.size(),
                thrust::make_zip_iterator
                (
                    thrust::make_tuple(phe.begin(), pp.begin(), pT.begin())
                ),
                thrust::make_zip_iterator
                (
                    thrust::make_tuple
                    (
                        pT.begin(),
                        ppsi.begin(),
                        pmu.begin(),
                        palpha.begin()
                    )
                ),
                hePsiThermoCalculateFunctor<mixtureThermoType>
                (
                    this->updateT()
                )
            );
        }
    }
}

// * * * * * * * * * * * * * * * * Constructors  * * * * * * * * * * * * * * //

template<class BasicPsiThermo, class MixtureType>
Foam::hePsiThermo<BasicPsiThermo, MixtureType>::hePsiThermo
(
    const gpufvMesh& mesh,
    const word& phaseName
)
:
    heThermo<BasicPsiThermo, MixtureType>(mesh, phaseName)
{
    // Initial evaluation of the thermo fields. Old-time levels are included
    // so that any existing old-time p/T levels are processed as well.
    const bool includeOldTimes = true;

    calculate
    (
        this->p_,
        this->T_,
        this->he_,
        this->psi_,
        this->mu_,
        this->alpha_,
        includeOldTimes
    );
}


template<class BasicPsiThermo, class MixtureType>
Foam::hePsiThermo<BasicPsiThermo, MixtureType>::hePsiThermo
(
    const gpufvMesh& mesh,
    const word& phaseName,
    const word& dictionaryName
)
:
    heThermo<BasicPsiThermo, MixtureType>(mesh, phaseName, dictionaryName)
{
    // Same initial evaluation as the (mesh, phaseName) constructor. Old-time
    // levels are included so that any existing old-time p/T levels are
    // processed as well.
    const bool includeOldTimes = true;

    calculate
    (
        this->p_,
        this->T_,
        this->he_,
        this->psi_,
        this->mu_,
        this->alpha_,
        includeOldTimes
    );
}



// * * * * * * * * * * * * * * * * Destructor  * * * * * * * * * * * * * * * //

template<class BasicPsiThermo, class MixtureType>
Foam::hePsiThermo<BasicPsiThermo, MixtureType>::~hePsiThermo()
{
    // No explicit cleanup is performed here
}


// * * * * * * * * * * * * * * * Member Functions  * * * * * * * * * * * * * //

template<class BasicPsiThermo, class MixtureType>
void Foam::hePsiThermo<BasicPsiThermo, MixtureType>::correct()
{
    DebugInFunction << endl;

    // Re-evaluate the current time level only; old-time levels are left
    // untouched here
    const bool includeOldTimes = false;

    calculate
    (
        this->p_, this->T_, this->he_,
        this->psi_, this->mu_, this->alpha_,
        includeOldTimes
    );

    DebugInFunction << "Finished" << endl;
}

// ************************************************************************* //
