/*---------------------------------------------------------------------------*\
  =========                 |
  \\      /  F ield         | OpenFOAM: The Open Source CFD Toolbox
   \\    /   O peration     |
    \\  /    A nd           | www.openfoam.com
     \\/     M anipulation  |
-------------------------------------------------------------------------------
    Copyright (C) 2013-2015 OpenFOAM Foundation
    Copyright (C) 2017-2019 OpenCFD Ltd.
-------------------------------------------------------------------------------
License
    This file is part of OpenFOAM.

    OpenFOAM is free software: you can redistribute it and/or modify it
    under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    OpenFOAM is distributed in the hope that it will be useful, but WITHOUT
    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
    for more details.

    You should have received a copy of the GNU General Public License
    along with OpenFOAM.  If not, see <http://www.gnu.org/licenses/>.

\*---------------------------------------------------------------------------*/

#include "GGAMGSolver.H"

#include <thrust/for_each.h>

// * * * * * * * * * * * * * Private Member Functions  * * * * * * * * * * * //

namespace Foam
{

// Element-wise update used by the plain GGAMGSolver::interpolate overload:
// given the accumulated Apsi value of a cell and its diagonal coefficient,
// returns -Apsi/diag (the new psi value for that cell).
// operator() is const so the functor can be invoked through const copies,
// as Thrust algorithms and iterator adaptors may do.
struct GAMGinterpolatePsiFunctor
{
    __host__ __device__
    scalar operator()(const scalar& Apsi, const scalar& diag) const
    {
        return -Apsi/diag;
    }
};

}

void Foam::GGAMGSolver::interpolate
(
    scalargpuField& psi,
    scalargpuField& Apsi,
    const gpulduMatrix& m,
    const FieldField<gpuField, scalar>& interfaceBouCoeffs,
    const lduInterfacegpuFieldPtrsList& interfaces,
    const direction cmpt
) const
{
    // Face addressing used by the off-diagonal product
    const labelgpuList& losortAddr = m.lduAddr().gpuOwnerSortAddr();
    const labelgpuList& upperAddressing = m.lduAddr().upperAddr();

    // Matrix coefficients; the lower coefficients are the "sorted" variant
    // that is paired with losortAddr below
    const scalargpuField& lowerSorted = m.gpuLowerSort();
    const scalargpuField& upperCoeffs = m.gpuUpper();
    const scalargpuField& diagCoeffs = m.gpuDiag();

    // coefficient * psi[neighbour], with no extra transform applied
    typedef matrixCoeffsMultiplyFunctor
    <
        scalar,
        scalar,
        thrust::identity<scalar>
    > coeffTimesPsi;

    // Request count handed to updateMatrixInterfaces below (presumably so
    // it only completes the requests posted by initMatrixInterfaces here)
    const label startRequest = Pstream::nRequests();

    m.initMatrixInterfaces
    (
        true,
        interfaceBouCoeffs,
        interfaces,
        psi,
        Apsi,
        cmpt
    );

    // Apsi = 0 + contributions of the upper and lower coefficients only;
    // the diagonal is not passed here (presumably excluded from Apsi)
    matrixFastOperation
    (
        thrust::make_constant_iterator(scalar(0.0)),
        Apsi,
        m.lduAddr(),
        coeffTimesPsi
        (
            psi.data(),
            upperCoeffs.data(),
            upperAddressing.data(),
            thrust::identity<scalar>()
        ),
        coeffTimesPsi
        (
            psi.data(),
            lowerSorted.data(),
            losortAddr.data(),
            thrust::identity<scalar>()
        )
    );

    m.updateMatrixInterfaces
    (
        true,
        interfaceBouCoeffs,
        interfaces,
        psi,
        Apsi,
        cmpt,
        startRequest
    );

    // Cell-by-cell: psi = -Apsi/diag
    thrust::transform
    (
        Apsi.begin(),
        Apsi.end(),
        diagCoeffs.begin(),
        psi.begin(),
        GAMGinterpolatePsiFunctor()
    );
}

namespace Foam
{

// Restriction-group reduction used by the re-normalising interpolate.
// Invoked with the [start, end) offsets of one group into the sort
// addressing; for the fine cells sort[start] .. sort[end-1] it returns the
// tuple
//     (sum(diag*psi), sum(diag)).
// The pointers are device pointers captured on the host and copied with the
// functor.
struct GAMGinterpolateCorrCdiagCFunctor
{
    const scalar* diag;
    const scalar* psi;
    const label* sort;

    GAMGinterpolateCorrCdiagCFunctor
    (
        const scalar* _diag,
        const scalar* _psi,
        const label* _sort
    ):
        diag(_diag),
        psi(_psi),
        sort(_sort)
    {}

    __host__ __device__
    thrust::tuple<scalar,scalar> operator()
    (
        const label& start,
        const label& end
    ) const
    {
        scalar corrC = 0;
        scalar diagC = 0;

        for (label i = start; i < end; i++)
        {
            const label celli = sort[i];

            corrC += diag[celli]*psi[celli];
            diagC += diag[celli];
        }

        return thrust::make_tuple(corrC, diagC);
    }
};


// Coarse correction for one coarse cell: psiC - corrC/diagC, where the
// tuple t holds (corrC, diagC) as produced by the reduction above.
// NOTE(review): no guard against diagC == 0; assumes every coarse cell has
// at least one fine cell with a non-zero diagonal sum.
struct GAMGinterpolateCorrCFunctor
{
    template<class Tuple>
    __host__ __device__
    scalar operator()(const scalar& psiC, const Tuple& t) const
    {
        return psiC - thrust::get<0>(t)/thrust::get<1>(t);
    }
};

}

namespace Foam
{

// Prolongation step of the re-normalising interpolate: adds one coarse-cell
// correction to every fine cell of a restriction group.  Invoked with a
// tuple (start, end, corr); the fine cells are psi[sort[i]] for i in
// [start, end).  Each fine cell is assumed to appear exactly once in the
// sort addressing, so different groups never write the same psi entry and
// no atomics are needed.
struct GAMGinterpolateAddCorrCFunctor
{
    scalar* psi;
    const label* sort;

    GAMGinterpolateAddCorrCFunctor
    (
        scalar* _psi,
        const label* _sort
    ):
        psi(_psi),
        sort(_sort)
    {}

    template<class Tuple>
    __host__ __device__
    void operator()(const Tuple& t) const
    {
        const label start = thrust::get<0>(t);
        const label end = thrust::get<1>(t);
        const scalar corr = thrust::get<2>(t);

        for (label i = start; i < end; i++)
        {
            psi[sort[i]] += corr;
        }
    }
};

}


// Interpolate psi (plain overload), then re-normalise it against the coarse
// solution psiC:
//     corrC[c] = psiC[c] - sum(diag*psi)/sum(diag)   over fine cells of c
//     psi[f]  += corrC[coarse(f)]
// restrictTargetStartAddressing holds (nGroups + 1) offsets into
// restrictSortAddressing (fine cell ids sorted by coarse cell), and
// restrictTargetAddressing gives the coarse cell of each group.
void Foam::GGAMGSolver::interpolate
(
    scalargpuField& psi,
    scalargpuField& Apsi,
    const gpulduMatrix& m,
    const FieldField<gpuField, scalar>& interfaceBouCoeffs,
    const lduInterfacegpuFieldPtrsList& interfaces,
    const labelgpuList& restrictSortAddressing,
    const labelgpuList& restrictTargetAddressing,
    const labelgpuList& restrictTargetStartAddressing,
    const scalargpuField& psiC,
    const direction cmpt
) const
{
    interpolate
    (
        psi,
        Apsi,
        m,
        interfaceBouCoeffs,
        interfaces,
        cmpt
    );

    // Fewer than two offsets means there are no restriction groups; the
    // "end()-1" ranges below would then be invalid.
    const label nGroups = restrictTargetStartAddressing.size() - 1;

    if (nGroups <= 0)
    {
        return;
    }

    const label nCCells = psiC.size();
    scalargpuField corrC(nCCells, 0);
    scalargpuField diagC(nCCells, 0);

    // Per group: corrC[target] = sum(diag*psi), diagC[target] = sum(diag)
    thrust::transform
    (
        restrictTargetStartAddressing.begin(),
        restrictTargetStartAddressing.end()-1,
        restrictTargetStartAddressing.begin()+1,
        thrust::make_zip_iterator(thrust::make_tuple
        (
            thrust::make_permutation_iterator
            (
                corrC.begin(),
                restrictTargetAddressing.begin()
            ),
            thrust::make_permutation_iterator
            (
                diagC.begin(),
                restrictTargetAddressing.begin()
            )
        )),
        GAMGinterpolateCorrCdiagCFunctor
        (
            m.gpuDiag().data(),
            psi.data(),
            restrictSortAddressing.data()
        )
    );

    // corrC = psiC - corrC/diagC
    thrust::transform
    (
        psiC.begin(),
        psiC.end(),
        thrust::make_zip_iterator(thrust::make_tuple
        (
            corrC.begin(),
            diagC.begin()
        )),
        corrC.begin(),
        GAMGinterpolateCorrCFunctor()
    );

    // Apply the correction: psi[fine] += corrC[coarse(fine)].
    // Without this step the correction computed above would be discarded.
    thrust::for_each
    (
        thrust::make_zip_iterator(thrust::make_tuple
        (
            restrictTargetStartAddressing.begin(),
            restrictTargetStartAddressing.begin()+1,
            thrust::make_permutation_iterator
            (
                corrC.begin(),
                restrictTargetAddressing.begin()
            )
        )),
        thrust::make_zip_iterator(thrust::make_tuple
        (
            restrictTargetStartAddressing.begin()+nGroups,
            restrictTargetStartAddressing.begin()+nGroups+1,
            thrust::make_permutation_iterator
            (
                corrC.begin(),
                restrictTargetAddressing.begin()+nGroups
            )
        )),
        GAMGinterpolateAddCorrCFunctor
        (
            psi.data(),
            restrictSortAddressing.data()
        )
    );
}


// ************************************************************************* //
