/*---------------------------------------------------------------------------*\
  =========                 |
  \\      /  F ield         | OpenFOAM: The Open Source CFD Toolbox
   \\    /   O peration     |
    \\  /    A nd           | www.openfoam.com
     \\/     M anipulation  |
-------------------------------------------------------------------------------
    Copyright (C) 2011-2017 OpenFOAM Foundation
    Copyright (C) 2021 OpenCFD Ltd.
-------------------------------------------------------------------------------
License
    This file is part of OpenFOAM.

    OpenFOAM is free software: you can redistribute it and/or modify it
    under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    OpenFOAM is distributed in the hope that it will be useful, but WITHOUT
    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
    for more details.

    You should have received a copy of the GNU General Public License
    along with OpenFOAM.  If not, see <http://www.gnu.org/licenses/>.

\*---------------------------------------------------------------------------*/

#include "multiComponentMixture.H"
#include <thrust/transform.h>
#include <thrust/gather.h>

// * * * * * * * * * * * * * Private Member Functions  * * * * * * * * * * * //

//- Populate speciesData_ from the per-species sub-dictionaries of thermoDict
//  (one sub-dictionary per entry of species_, looked up by species name).
//  Returns the thermo data of the first species, so it can be used directly
//  in a member-initialiser list.
template<class ThermoType>
const ThermoType& Foam::multiComponentMixture<ThermoType>::constructSpeciesData
(
    const dictionary& thermoDict
)
{
    const label nSpecies = species_.size();

    for (label speciei = 0; speciei < nSpecies; ++speciei)
    {
        const dictionary& specieDict = thermoDict.subDict(species_[speciei]);
        speciesData_.set(speciei, new ThermoType(specieDict));
    }

    return speciesData_[0];
}


//- Normalise the species mass fractions Y_ so that they sum to one.
//  Aborts with a FatalError if the minimum of the summed field is below
//  ROOTVSMALL, and emits a warning if the maximum of the sum is not exactly 1.
template<class ThermoType>
void Foam::multiComponentMixture<ThermoType>::correctMassFractions()
{
    // Scaling by 1.0 creates a new field whose patches are "calculated"
    // rather than inheriting the boundary types of Y_[0].
    volScalargpuField Yt("Yt", 1.0*Y_[0]);

    for (label speciei = 1; speciei < Y_.size(); ++speciei)
    {
        Yt += Y_[speciei];
    }

    if (mag(min(Yt).value()) < ROOTVSMALL)
    {
        FatalErrorInFunction
            << "Sum of mass fractions is zero for species " << this->species()
            << exit(FatalError);
    }

    // Exact comparison on purpose: any deviation is reported, then corrected
    // by the division below.
    const bool sumIsUnity = (mag(max(Yt).value()) == scalar(1));

    if (!sumIsUnity)
    {
        WarningInFunction
            << "Sum of mass fractions is different from one for species "
            << this->species()
            << nl;
    }

    for (label speciei = 0; speciei < Y_.size(); ++speciei)
    {
        Y_[speciei] /= Yt;
    }
}


//- Allocate and initialise the device-side storage:
//  - gSpeciesData_   : device copy of speciesData_ (one entry per species)
//  - mixtureCells_   : one mixture per cell, initialised to mixture_
//  - mixtureVolCells_: one volume mixture per cell, initialised to mixtureVol_
//  Must be called at most once; calling it with any of these buffers already
//  allocated is a FatalError (the old buffer would otherwise be leaked).
template<class ThermoType>
void Foam::multiComponentMixture<ThermoType>::calcMem(const gpufvMesh& mesh)
{
    // Any non-null buffer would be overwritten (and leaked) below, so the
    // guard must trigger if *any* of them is allocated, not only if all are.
    if (gSpeciesData_ || mixtureCells_ || mixtureVolCells_)
    {
        FatalErrorInFunction
            << "gSpeciesData_, mixtureCells_ or mixtureVolCells_ allocated"
            << abort(FatalError);
    }

    const label nSpecies = species_.size();
    const label nCells = mesh.hostmesh().nCells();

    // Upload the per-species thermo data to the device
    gSpeciesData_ = thrust::device_malloc<ThermoType>(nSpecies);
    thrust::copy(speciesData_.begin(), speciesData_.end(), gSpeciesData_);

    // Per-cell mixture buffers, seeded with the host-side placeholders
    mixtureCells_ = thrust::device_malloc<ThermoType>(nCells);
    mixtureVolCells_ = thrust::device_malloc<ThermoType>(nCells);

    thrust::fill(mixtureCells_, mixtureCells_ + nCells, mixture_);
    thrust::fill(mixtureVolCells_, mixtureVolCells_ + nCells, mixtureVol_);
}


// * * * * * * * * * * * * * * * * Constructors  * * * * * * * * * * * * * * //

//- Construct from a thermo dictionary, an explicit list of species names and
//  a table of pre-built thermo data keyed by species name. The thermo data
//  are deep-copied into speciesData_, Y_ is normalised and the device
//  buffers are allocated.
template<class ThermoType>
Foam::multiComponentMixture<ThermoType>::multiComponentMixture
(
    const dictionary& thermoDict,
    const wordList& specieNames,
    const ReactionTable<ThermoType>& thermoData,
    const gpufvMesh& mesh,
    const word& phaseName
)
:
    basicSpecieMixture(thermoDict, specieNames, mesh, phaseName),
    speciesData_(species_.size()),
    mixture_("mixture", *thermoData[specieNames[0]]),
    mixtureVol_("volMixture", *thermoData[specieNames[0]]),
    gSpeciesData_(nullptr),
    mixtureCells_(nullptr),
    mixtureVolCells_(nullptr),
    mixtureFaces_(nullptr),
    mixtureVolFaces_(nullptr)
{
    // Copy each species' thermo data out of the table, in species order
    for (label speciei = 0; speciei < species_.size(); ++speciei)
    {
        const word& specieName = species_[speciei];
        speciesData_.set(speciei, new ThermoType(*thermoData[specieName]));
    }

    // Normalise Y_ first, then set up the device-side storage
    correctMassFractions();
    calcMem(mesh);
}


//- Construct from a thermo dictionary only: the species names are read from
//  its "species" entry and each species' thermo data from the sub-dictionary
//  named after that species. Y_ is then normalised and the device buffers
//  are allocated.
template<class ThermoType>
Foam::multiComponentMixture<ThermoType>::multiComponentMixture
(
    const dictionary& thermoDict,
    const gpufvMesh& mesh,
    const word& phaseName
)
:
    basicSpecieMixture
    (
        thermoDict,
        thermoDict.lookup("species"),
        mesh,
        phaseName
    ),
    speciesData_(species_.size()),
    // constructSpeciesData() fills speciesData_ as a side effect and returns
    // its first entry; mixtureVol_ below then reads speciesData_[0].
    // NOTE(review): this relies on the member declaration order
    // speciesData_ < mixture_ < mixtureVol_ in the class — confirm in header.
    mixture_("mixture", constructSpeciesData(thermoDict)),
    mixtureVol_("volMixture", speciesData_[0]),
    gSpeciesData_(nullptr),
    mixtureCells_(nullptr),
    mixtureVolCells_(nullptr),
    mixtureFaces_(nullptr),
    mixtureVolFaces_(nullptr)
{
    // Normalise Y_ first, then set up the device-side storage
    correctMassFractions();
    calcMem(mesh);
}

// * * * * * * * * * * * * * * * * Destructor  * * * * * * * * * * * * * * * //
//- Release all device buffers owned by this mixture.
template<class ThermoType>
Foam::multiComponentMixture<ThermoType>::~multiComponentMixture()
{
    // Per-cell buffers (allocated in calcMem)
    thrust::device_free(mixtureCells_);
    thrust::device_free(mixtureVolCells_);

    // Per-patch-face buffers: only allocated by the updatePatchFace*Mixture
    // functions, otherwise still nullptr from construction
    thrust::device_free(mixtureFaces_);
    thrust::device_free(mixtureVolFaces_);

    // Device copy of the species thermo data
    thrust::device_free(gSpeciesData_);
}

// * * * * * * * * * * * * * * * Member Functions  * * * * * * * * * * * * * //

//- Return the mixture for cell celli.
//  NOTE(review): currently a placeholder — celli is ignored and the single
//  host-side mixture_ is returned. The per-cell results computed by
//  updateCellMixture() live in the device buffer mixtureCells_ and are not
//  returned here (see the commented-out line).
template<class ThermoType>
__host__ __device__
const ThermoType& Foam::multiComponentMixture<ThermoType>::cellMixture
(
    const label celli
) const
{
    //return mixtureCells_[celli];
    return mixture_;
}


//- Return the mixture for face facei of patch patchi.
//  NOTE(review): currently a placeholder — patchi and facei are ignored and
//  the single host-side mixture_ is returned. Per-face results are written
//  by updatePatchFaceMixture() into the device buffer mixtureFaces_.
template<class ThermoType>
__host__ __device__
const ThermoType& Foam::multiComponentMixture<ThermoType>::patchFaceMixture
(
    const label patchi,
    const label facei
) const
{
    //return mixtureFaces_[facei];
    return mixture_;
}


//- Return the volume-weighted mixture for cell celli at pressure p and
//  temperature T.
//  NOTE(review): currently a placeholder — p, T and celli are ignored. The
//  per-cell results computed by updateCellVolMixture() live in the device
//  buffer mixtureVolCells_ and are not returned here.
template<class ThermoType>
__host__ __device__
const ThermoType& Foam::multiComponentMixture<ThermoType>::cellVolMixture
(
    const scalar p,
    const scalar T,
    const label celli
) const
{
    // Return the volume-mixture placeholder ("volMixture") rather than
    // mixture_, consistent with calcMem() which seeds mixtureVolCells_ from
    // mixtureVol_.
    //return mixtureVolCells_[celli];
    return mixtureVol_;
}


//- Return the volume-weighted mixture for face facei of patch patchi at
//  pressure p and temperature T.
//  NOTE(review): currently a placeholder — all arguments are ignored. Per-face
//  results are written by updatePatchFaceVolMixture() into the device buffer
//  mixtureVolFaces_.
template<class ThermoType>
__host__ __device__
const ThermoType& Foam::multiComponentMixture<ThermoType>::
patchFaceVolMixture
(
    const scalar p,
    const scalar T,
    const label patchi,
    const label facei
) const
{
    // Return the volume-mixture placeholder ("volMixture") rather than
    // mixture_, so the vol accessors agree with the vol device buffers.
    //return mixtureVolFaces_[facei];
    return mixtureVol_;
}


//- Re-read the per-species thermo data from thermoDict (one sub-dictionary
//  per species name) and refresh the device copy gSpeciesData_.
//  NOTE(review): the per-cell/per-face device mixtures are not recomputed
//  here; they are refreshed only by the update*Mixture functions.
template<class ThermoType>
void Foam::multiComponentMixture<ThermoType>::read
(
    const dictionary& thermoDict
)
{
    for (label speciei = 0; speciei < species_.size(); ++speciei)
    {
        const dictionary& specieDict = thermoDict.subDict(species_[speciei]);
        speciesData_[speciei] = ThermoType(specieDict);
    }

    // Push the refreshed coefficients to the device
    thrust::copy(speciesData_.begin(), speciesData_.end(), gSpeciesData_);
}


namespace Foam
{

//- Functor computing the mass-fraction-weighted mixture of one row
//  (cell or patch face).
//  Y is row-major with m species per row: Y[row*m + i] is the mass fraction
//  of species i in that row. speciesData points to m species thermo objects.
//  n is the number of rows; the functor is applied to idx in [0, n).
template <class ThermoType>
struct mixtureFunctor
{
    const scalar* Y;
    const ThermoType* speciesData;
    label m;
    label n;

    mixtureFunctor
    (
        const scalar* Yin,
        const ThermoType* speciesDataIn,
        label mIn,
        label nIn
    )
    :
        Y(Yin),
        speciesData(speciesDataIn),
        m(mIn),
        n(nIn)
    {}

    __host__ __device__
    ThermoType operator()(const label idx) const
    {
        const label row = idx % n;
        const scalar* Yrow = Y + row*m;

        ThermoType tmpMixture = Yrow[0]*speciesData[0];

        for (label i = 1; i < m; ++i)
        {
            tmpMixture += Yrow[i]*speciesData[i];
        }

        return tmpMixture;
    }
};


//- Index map used with thrust::gather to re-order a species-major array
//  (in[c*cols + r]) into a row-major one (out[r*rows + c]).
//  For an output index idx = r*rows + c it returns c*cols + r.
//  In this file it is constructed as TransposeFunctor(nSpecies, nRows), so
//  out[row*nSpecies + i] = in[i*nRows + row].
struct TransposeFunctor
{
    const label rows, cols;

    TransposeFunctor(label rowsIn, label colsIn) : rows(rowsIn), cols(colsIn) {}

    __host__ __device__
    label operator()(const label& idx) const
    {
        const label row = idx / rows;
        const label col = idx % rows;
        return col*cols + row;
    }
};


//- Functor computing the volume-fraction-weighted mixture of one row.
//  Layout of Y as for mixtureFunctor; p[idx] and T[idx] are the pressure and
//  temperature of that row.
//  With 1/rho = sum_i Y_i/rho_i(p, T), the mixture is
//      sum_i (Y_i/rho_i(p, T))/(1/rho) * species_i
//  i.e. rhoInv must be complete before it is used as the normaliser.
template <class ThermoType>
struct volMixtureFunctor
{
    const scalar* Y;
    const ThermoType* speciesData_;
    label m;
    label n;
    const scalar* p;
    const scalar* T;

    volMixtureFunctor
    (
        const scalar* Yin,
        const ThermoType* speciesDataIn,
        label mIn,
        label nIn,
        const scalar* pIn,
        const scalar* TIn
    )
    :
        Y(Yin),
        speciesData_(speciesDataIn),
        m(mIn),
        n(nIn),
        p(pIn),
        T(TIn)
    {}

    __host__ __device__
    ThermoType operator()(const label idx) const
    {
        const label row = idx % n;
        const scalar* Yrow = Y + row*m;
        const scalar pRow = p[idx];
        const scalar TRow = T[idx];

        // First pass: full mixture specific volume, sum_i Y_i/rho_i
        scalar rhoInv = 0;
        for (label i = 0; i < m; ++i)
        {
            rhoInv += Yrow[i]/speciesData_[i].rho(pRow, TRow);
        }

        // Second pass: volume-fraction-weighted sum, each species once
        ThermoType mixtureVol =
            Yrow[0]/speciesData_[0].rho(pRow, TRow)/rhoInv*speciesData_[0];

        for (label i = 1; i < m; ++i)
        {
            mixtureVol +=
                Yrow[i]/speciesData_[i].rho(pRow, TRow)/rhoInv*speciesData_[i];
        }

        return mixtureVol;
    }
};

} // End namespace Foam


//- Recompute the per-cell mass-fraction-weighted mixtures into the device
//  buffer mixtureCells_ (allocated in calcMem) from the current Y_ and the
//  device species data gSpeciesData_.
template<class ThermoType>
void Foam::multiComponentMixture<ThermoType>::updateCellMixture()
{
    const label nSpecies = Y_.size();
    const label nCells = Y_[0].size();
    const label nEntries = nSpecies*nCells;

    // Pack Y_ into one flat array, species after species (species-major):
    // tmpY[i*nCells + celli] = Y_[i][celli]
    scalargpuField tmpY(nEntries);
    forAll(Y_, i)
    {
        thrust::copy
        (
            Y_[i].primitiveField().begin(),
            Y_[i].primitiveField().end(),
            tmpY.begin() + i*nCells
        );
    }

    // Build the transpose map. Indices are stored as label (not scalar) and
    // generated with a label counting iterator so that, with 64-bit labels,
    // nSpecies*nCells is not truncated to int.
    gpuList<label> transposeMap(nEntries);
    thrust::transform
    (
        thrust::counting_iterator<label>(0),
        thrust::counting_iterator<label>(nEntries),
        transposeMap.begin(),
        TransposeFunctor(nSpecies, nCells)
    );

    // Row-major layout read by mixtureFunctor:
    // tmpYT[celli*nSpecies + i] = Y_[i][celli]
    scalargpuField tmpYT(nEntries);
    thrust::gather
    (
        transposeMap.begin(),
        transposeMap.end(),
        tmpY.begin(),
        tmpYT.begin()
    );

    thrust::transform
    (
        thrust::counting_iterator<label>(0),
        thrust::counting_iterator<label>(nCells),
        mixtureCells_,
        mixtureFunctor<ThermoType>
        (
            tmpYT.data(),
            raw_pointer_cast(gSpeciesData_),
            nSpecies,
            nCells
        )
    );
}


//- Recompute the mass-fraction-weighted mixtures of every face of patch
//  patchi into mixtureFaces_ and return a reference to that device buffer.
//  The buffer is freed and re-allocated on every call, sized to the patch.
template<class ThermoType>
const thrust::device_ptr<ThermoType>& Foam::multiComponentMixture<ThermoType>::updatePatchFaceMixture
(
    const label patchi
) const
{
    const label nSpecies = Y_.size();
    const label nFaces = Y_[0].boundaryField()[patchi].size();
    const label nEntries = nSpecies*nFaces;

    if (mixtureFaces_)
    {
        thrust::device_free(mixtureFaces_);
    }

    mixtureFaces_ = thrust::device_malloc<ThermoType>(nFaces);

    // Pack the patch values species after species (species-major):
    // tmpY[i*nFaces + facei] = Y_[i].boundaryField()[patchi][facei]
    scalargpuField tmpY(nEntries);
    forAll(Y_, i)
    {
        thrust::copy
        (
            Y_[i].boundaryField()[patchi].begin(),
            Y_[i].boundaryField()[patchi].end(),
            tmpY.begin() + i*nFaces
        );
    }

    // Transpose map with label-typed indices (not scalar, not int)
    gpuList<label> transposeMap(nEntries);
    thrust::transform
    (
        thrust::counting_iterator<label>(0),
        thrust::counting_iterator<label>(nEntries),
        transposeMap.begin(),
        TransposeFunctor(nSpecies, nFaces)
    );

    // Row-major layout: tmpYT[facei*nSpecies + i]
    scalargpuField tmpYT(nEntries);
    thrust::gather
    (
        transposeMap.begin(),
        transposeMap.end(),
        tmpY.begin(),
        tmpYT.begin()
    );

    thrust::transform
    (
        thrust::counting_iterator<label>(0),
        thrust::counting_iterator<label>(nFaces),
        mixtureFaces_,
        mixtureFunctor<ThermoType>
        (
            tmpYT.data(),
            raw_pointer_cast(gSpeciesData_),
            nSpecies,
            nFaces
        )
    );

    return mixtureFaces_;
}


//- Recompute the per-cell volume-fraction-weighted mixtures into the device
//  buffer mixtureVolCells_ (allocated in calcMem), using the cell pressures
//  pCells and temperatures TCells. Both must hold at least one value per
//  cell; otherwise the device functor would read out of bounds.
template<class ThermoType>
void Foam::multiComponentMixture<ThermoType>::updateCellVolMixture
(
    const scalargpuField& pCells,
    const scalargpuField& TCells
)
{
    const label nSpecies = Y_.size();
    const label nCells = Y_[0].size();
    const label nEntries = nSpecies*nCells;

    if (pCells.size() < nCells || TCells.size() < nCells)
    {
        FatalErrorInFunction
            << "p/T size (" << pCells.size() << "/" << TCells.size()
            << ") smaller than number of cells " << nCells
            << abort(FatalError);
    }

    // Pack Y_ into one flat array, species after species (species-major):
    // tmpY[i*nCells + celli] = Y_[i][celli]
    scalargpuField tmpY(nEntries);
    forAll(Y_, i)
    {
        thrust::copy
        (
            Y_[i].primitiveField().begin(),
            Y_[i].primitiveField().end(),
            tmpY.begin() + i*nCells
        );
    }

    // Transpose map with label-typed indices (not scalar, not int)
    gpuList<label> transposeMap(nEntries);
    thrust::transform
    (
        thrust::counting_iterator<label>(0),
        thrust::counting_iterator<label>(nEntries),
        transposeMap.begin(),
        TransposeFunctor(nSpecies, nCells)
    );

    // Row-major layout read by volMixtureFunctor:
    // tmpYT[celli*nSpecies + i] = Y_[i][celli]
    scalargpuField tmpYT(nEntries);
    thrust::gather
    (
        transposeMap.begin(),
        transposeMap.end(),
        tmpY.begin(),
        tmpYT.begin()
    );

    thrust::transform
    (
        thrust::counting_iterator<label>(0),
        thrust::counting_iterator<label>(nCells),
        mixtureVolCells_,
        volMixtureFunctor<ThermoType>
        (
            tmpYT.data(),
            raw_pointer_cast(gSpeciesData_),
            nSpecies,
            nCells,
            pCells.data(),
            TCells.data()
        )
    );
}

//- Recompute the volume-fraction-weighted mixtures of every face of patch
//  patchi into mixtureVolFaces_ and return a reference to that device buffer.
//  pCells/TCells are the pressure/temperature values for the patch faces and
//  must hold at least one value per face. The buffer is freed and
//  re-allocated on every call.
template <class ThermoType>
const thrust::device_ptr<ThermoType>& Foam::multiComponentMixture<ThermoType>::updatePatchFaceVolMixture
(
    const scalargpuField &pCells,
    const scalargpuField &TCells,
    const label patchi
) const
{
    const label nSpecies = Y_.size();
    const label nFaces = Y_[0].boundaryField()[patchi].size();
    const label nEntries = nSpecies*nFaces;

    if (pCells.size() < nFaces || TCells.size() < nFaces)
    {
        FatalErrorInFunction
            << "p/T size (" << pCells.size() << "/" << TCells.size()
            << ") smaller than number of faces " << nFaces
            << " on patch " << patchi
            << abort(FatalError);
    }

    if (mixtureVolFaces_)
    {
        thrust::device_free(mixtureVolFaces_);
    }

    mixtureVolFaces_ = thrust::device_malloc<ThermoType>(nFaces);

    // Pack the patch values species after species (species-major):
    // tmpY[i*nFaces + facei] = Y_[i].boundaryField()[patchi][facei]
    scalargpuField tmpY(nEntries);
    forAll(Y_, i)
    {
        thrust::copy
        (
            Y_[i].boundaryField()[patchi].begin(),
            Y_[i].boundaryField()[patchi].end(),
            tmpY.begin() + i*nFaces
        );
    }

    // Transpose map with label-typed indices (not scalar, not int)
    gpuList<label> transposeMap(nEntries);
    thrust::transform
    (
        thrust::counting_iterator<label>(0),
        thrust::counting_iterator<label>(nEntries),
        transposeMap.begin(),
        TransposeFunctor(nSpecies, nFaces)
    );

    // Row-major layout: tmpYT[facei*nSpecies + i]
    scalargpuField tmpYT(nEntries);
    thrust::gather
    (
        transposeMap.begin(),
        transposeMap.end(),
        tmpY.begin(),
        tmpYT.begin()
    );

    thrust::transform
    (
        thrust::counting_iterator<label>(0),
        thrust::counting_iterator<label>(nFaces),
        mixtureVolFaces_,
        volMixtureFunctor<ThermoType>
        (
            tmpYT.data(),
            raw_pointer_cast(gSpeciesData_),
            nSpecies,
            nFaces,
            pCells.data(),
            TCells.data()
        )
    );

    return mixtureVolFaces_;
}
// ************************************************************************* //
