/*---------------------------------------------------------------------------*\
  =========                 |
  \\      /  F ield         | OpenFOAM: The Open Source CFD Toolbox
   \\    /   O peration     |
    \\  /    A nd           | www.openfoam.com
     \\/     M anipulation  |
-------------------------------------------------------------------------------
    Copyright (C) 2018 OpenFOAM Foundation
    Copyright (C) 2021 OpenCFD Ltd.
-------------------------------------------------------------------------------
License
    This file is part of OpenFOAM.

    OpenFOAM is free software: you can redistribute it and/or modify it
    under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    OpenFOAM is distributed in the hope that it will be useful, but WITHOUT
    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
    for more details.

    You should have received a copy of the GNU General Public License
    along with OpenFOAM.  If not, see <http://www.gnu.org/licenses/>.

\*---------------------------------------------------------------------------*/

#include "gpucellLimitedGrad.H"
#include "gpugaussGrad.H"

// * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * //


namespace Foam
{
    //- Per-face functor used by gpucellLimitedGrad::calcGrad.
    //  The zipped tuple is
    //      (limiter, maxDelta, minDelta, faceCentre, cellCentre, cellGrad)
    //  and the functor forwards the projected extrapolation
    //      (faceCentre - cellCentre) & cellGrad
    //  to GradScheme::limitFace together with the limiter entry, which is
    //  presumably lowered in place by limitFace (not visible here).
    //
    //  NOTE(review): the scheme is held by reference and the call operator
    //  is __host__ __device__.  On a device backend that reference points at
    //  a host object, which is only safe if limitFace never reads member
    //  state of the scheme (e.g. limiter coefficients) - TODO confirm against
    //  gpucellLimitedGrad.H and the Limiter implementations.
    template<class GradScheme>
    struct cellLimitedGradcalcGradFunctor
    {
        //- Scheme providing limitFace (not owned)
        const GradScheme& cellLimited;

        //- Construct from the scheme whose limitFace is invoked per face.
        //  Explicit to avoid accidental implicit conversions.
        explicit cellLimitedGradcalcGradFunctor(const GradScheme& cellLimited_)
        :
            cellLimited(cellLimited_)
        {}

        template<class Tuple>
        __host__ __device__
        void operator()(const Tuple& t) const
        {
            cellLimited.limitFace
            (
                thrust::get<0>(t),  // limiter entry of the cell (in/out)
                thrust::get<1>(t),  // max admissible delta of the cell
                thrust::get<2>(t),  // min admissible delta of the cell
                // extrapolated change from cell centre to face centre
                (thrust::get<3>(t) - thrust::get<4>(t)) & thrust::get<5>(t)
            );
        }
    };

    //- Scales a tensor gradient by a per-component vector limiter:
    //  each of t.x(), t.y() and t.z() is multiplied component-wise by v.
    struct cellLimitedGradlimitGradientFunctor
    {
        __host__ __device__
        tensor operator()(const vector& v, const tensor& t) const
        {
            return tensor
            (
                cmptMultiply(v, t.x()),
                cmptMultiply(v, t.y()),
                cmptMultiply(v, t.z())
            );
        }
    };
}


template<class Type, class Limiter>
void Foam::fv::gpucellLimitedGrad<Type, Limiter>::limitGradient
(
    const gpuField<scalar>& limiter,
    gpuField<vector>& gIf
) const
{
    // Scalar-limiter case: each cell's gradient vector is scaled by that
    // cell's single limiter value (element-wise gIf[i] *= limiter[i]).
    gpuField<vector>& limitedGrad = gIf;
    limitedGrad *= limiter;
}


template<class Type, class Limiter>
void Foam::fv::gpucellLimitedGrad<Type, Limiter>::limitGradient
(
    const gpuField<vector>& limiter,
    gpuField<tensor>& gIf
) const
{
    // Vector-limiter case: for every cell, each of the tensor's x(), y()
    // and z() vectors is multiplied component-wise by that cell's limiter
    // vector.  The result overwrites gIf in place.
    cellLimitedGradlimitGradientFunctor scaleByLimiter;

    thrust::transform
    (
        limiter.begin(),
        limiter.end(),
        gIf.begin(),
        gIf.begin(),
        scaleByLimiter
    );
}

// Compute the cell-limited gradient of vsf.
//
// Steps:
//  1. Evaluate the unlimited gradient with basicGradScheme_.  If k_ < SMALL
//     limiting is disabled and that gradient is returned unchanged.
//  2. For every cell gather the max/min of its own value and the values
//     across each of its faces: the other cell of internal faces, the
//     patchNeighbourField() on coupled patches, and the patch face values
//     on uncoupled patches.
//  3. Turn these extrema into deltas relative to the cell value; when
//     k_ < 1 widen the band by (1/k_ - 1)*(maxDelta - minDelta) each side.
//  4. Starting from a limiter of one per cell, pass every face's
//     extrapolated change (Cf - C) & grad of the adjacent cell to limitFace.
//  5. Apply the limiter through limitGradient, then correct the boundary
//     conditions of the result.
//
// NOTE(review): steps 2 and 4 scatter through thrust permutation iterators
// indexed by owner / neighbour / faceCells.  Many faces share one cell, so
// on a parallel (device) backend the read-modify-write of maxVsf[c],
// minVsf[c] and limiter[c] issued from different faces can race and lose
// updates.  In-place thrust::transform only covers output == input element
// by element, not aliasing through repeated indices.  A race-free version
// needs a per-cell gather over the cell's faces or atomics - TODO confirm
// which addressing gpufvMesh provides and fix.
//
// In the end-of-range iterators below, the range length comes from the
// index iterators (owner.end(), ...); the base iterators of the end
// permutation iterators (limiter.end(), ...) are never dereferenced.
template<class Type, class Limiter>
Foam::tmp
<
    Foam::GeometricgpuField
    <
        typename Foam::outerProduct<Foam::vector, Type>::type,
        Foam::fvPatchgpuField,
        Foam::gpuvolMesh
    >
>
Foam::fv::gpucellLimitedGrad<Type, Limiter>::calcGrad
(
    const GeometricgpuField<Type, fvPatchgpuField, gpuvolMesh>& vsf,
    const word& name
) const
{
    const gpufvMesh& mesh = vsf.mesh();

    // Unlimited gradient from the underlying scheme
    tmp
    <
        GeometricgpuField
        <typename outerProduct<vector, Type>::type, fvPatchgpuField, gpuvolMesh>
    > tGrad = basicGradScheme_().calcGrad(vsf, name);

    // Limiting switched off: return the unlimited gradient as-is
    if (k_ < SMALL)
    {
        return tGrad;
    }

    // Gradient is limited in place and returned through tGrad
    GeometricgpuField
    <
        typename outerProduct<vector, Type>::type,
        fvPatchgpuField,
        gpuvolMesh
    >& g = tGrad.ref();

    const labelgpuList& owner = mesh.owner();
    const labelgpuList& neighbour = mesh.neighbour();

    const volVectorgpuField& C = mesh.C();
    const surfaceVectorgpuField& Cf = mesh.Cf();

    // Per-cell extrema, seeded with the cell's own value
    gpuField<Type> maxVsf(vsf.primitiveField());
    gpuField<Type> minVsf(vsf.primitiveField());


    // Internal faces, owner side: fold the neighbour value into the owner's
    // max.  Output scatters to owner[facei] (see race note above).
    thrust::transform(thrust::make_permutation_iterator(maxVsf.begin(), owner.begin()),
                      thrust::make_permutation_iterator(maxVsf.begin(), owner.end()),
                      thrust::make_permutation_iterator(vsf.field().begin(), neighbour.begin()),
                      thrust::make_permutation_iterator(maxVsf.begin(), owner.begin()),
                      maxBinaryFunctionFunctor<Type, Type, Type>());

    // Internal faces, owner side: fold the neighbour value into the owner's min
    thrust::transform(thrust::make_permutation_iterator(minVsf.begin(), owner.begin()),
                      thrust::make_permutation_iterator(minVsf.begin(), owner.end()),
                      thrust::make_permutation_iterator(vsf.field().begin(), neighbour.begin()),
                      thrust::make_permutation_iterator(minVsf.begin(), owner.begin()),
                      minBinaryFunctionFunctor<Type, Type, Type>());

    // Internal faces, neighbour side: fold the owner value into the
    // neighbour's max.  Output scatters to neighbour[facei].
    thrust::transform(thrust::make_permutation_iterator(maxVsf.begin(), neighbour.begin()),
                      thrust::make_permutation_iterator(maxVsf.begin(), neighbour.end()),
                      thrust::make_permutation_iterator(vsf.field().begin(), owner.begin()),
                      thrust::make_permutation_iterator(maxVsf.begin(), neighbour.begin()),
                      maxBinaryFunctionFunctor<Type, Type, Type>());

    // Internal faces, neighbour side: fold the owner value into the
    // neighbour's min
    thrust::transform(thrust::make_permutation_iterator(minVsf.begin(), neighbour.begin()),
                      thrust::make_permutation_iterator(minVsf.begin(), neighbour.end()),
                      thrust::make_permutation_iterator(vsf.field().begin(), owner.begin()),
                      thrust::make_permutation_iterator(minVsf.begin(), neighbour.begin()),
                      minBinaryFunctionFunctor<Type, Type, Type>());

    const typename GeometricgpuField<Type, fvPatchgpuField, gpuvolMesh>::Boundary& bsf =
        vsf.boundaryField();

    // Boundary faces: fold the value across each patch face into the
    // extrema of the face-adjacent cell
    forAll(bsf, patchi)
    {
        const fvPatchgpuField<Type>& psf = bsf[patchi];
        const labelgpuList& pOwner = mesh.boundary()[patchi].gpuFaceCells();

        if (psf.coupled())
        {
            // Coupled patch: compare against the value on the other side
            const gpuField<Type> psfNei(psf.patchNeighbourField());


            thrust::transform(thrust::make_permutation_iterator(maxVsf.begin(), pOwner.begin()),
                              thrust::make_permutation_iterator(maxVsf.begin(), pOwner.end()),
                              psfNei.begin(),
                              thrust::make_permutation_iterator(maxVsf.begin(), pOwner.begin()),
                              maxBinaryFunctionFunctor<Type, Type, Type>());

            thrust::transform(thrust::make_permutation_iterator(minVsf.begin(), pOwner.begin()),
                              thrust::make_permutation_iterator(minVsf.begin(), pOwner.end()),
                              psfNei.begin(),
                              thrust::make_permutation_iterator(minVsf.begin(), pOwner.begin()),
                              minBinaryFunctionFunctor<Type, Type, Type>());
        }
        else
        {
            // Uncoupled patch: compare against the patch face values

            thrust::transform(thrust::make_permutation_iterator(maxVsf.begin(), pOwner.begin()),
                              thrust::make_permutation_iterator(maxVsf.begin(), pOwner.end()),
                              psf.begin(),
                              thrust::make_permutation_iterator(maxVsf.begin(), pOwner.begin()),
                              maxBinaryFunctionFunctor<Type, Type, Type>());

            thrust::transform(thrust::make_permutation_iterator(minVsf.begin(), pOwner.begin()),
                              thrust::make_permutation_iterator(minVsf.begin(), pOwner.end()),
                              psf.begin(),
                              thrust::make_permutation_iterator(minVsf.begin(), pOwner.begin()),
                              minBinaryFunctionFunctor<Type, Type, Type>());
        }
    }

    // Convert extrema to deltas relative to the cell value
    maxVsf -= vsf;
    minVsf -= vsf;

    // k_ < 1 relaxes the limiting by widening the admissible band on both
    // sides by (1/k_ - 1) times its current width
    if (k_ < 1.0)
    {
        const gpuField<Type> maxMinVsf((1.0/k_ - 1.0)*(maxVsf - minVsf));
        maxVsf += maxMinVsf;
        minVsf -= maxMinVsf;
    }


    // Create limiter initialized to 1
    // Note: the limiter is not permitted to be > 1
    gpuField<Type> limiter(vsf.primitiveField().size(), pTraits<Type>::one);


    // Internal faces, owner side: tuple is (limiter, maxDelta, minDelta,
    // Cf, C, grad) with every cell quantity taken at owner[facei].
    // Output scatters to limiter[owner[facei]] (see race note above).
    thrust::for_each(thrust::make_zip_iterator(thrust::make_tuple(
                         thrust::make_permutation_iterator(
                             limiter.begin(),
                             owner.begin()),
                         thrust::make_permutation_iterator(
                             maxVsf.begin(),
                             owner.begin()),
                         thrust::make_permutation_iterator(
                             minVsf.begin(),
                             owner.begin()),
                         Cf.field().begin(),
                         thrust::make_permutation_iterator(
                             C.field().begin(),
                             owner.begin()),
                         thrust::make_permutation_iterator(
                             g.field().begin(),
                             owner.begin()))),
                     thrust::make_zip_iterator(thrust::make_tuple(
                         thrust::make_permutation_iterator(
                             limiter.end(),
                             owner.end()),
                         thrust::make_permutation_iterator(
                             maxVsf.end(),
                             owner.end()),
                         thrust::make_permutation_iterator(
                             minVsf.end(),
                             owner.end()),
                         Cf.field().end(),
                         thrust::make_permutation_iterator(
                             C.field().end(),
                             owner.end()),
                         thrust::make_permutation_iterator(
                             g.field().end(),
                             owner.end()))),
                     cellLimitedGradcalcGradFunctor<gpucellLimitedGrad>(*this));

    // Internal faces, neighbour side: same face centre, cell quantities
    // taken at neighbour[facei]
    thrust::for_each(thrust::make_zip_iterator(thrust::make_tuple(
                         thrust::make_permutation_iterator(
                             limiter.begin(),
                             neighbour.begin()),
                         thrust::make_permutation_iterator(
                             maxVsf.begin(),
                             neighbour.begin()),
                         thrust::make_permutation_iterator(
                             minVsf.begin(),
                             neighbour.begin()),
                         Cf.field().begin(),
                         thrust::make_permutation_iterator(
                             C.field().begin(),
                             neighbour.begin()),
                         thrust::make_permutation_iterator(
                             g.field().begin(),
                             neighbour.begin()))),
                     thrust::make_zip_iterator(thrust::make_tuple(
                         thrust::make_permutation_iterator(
                             limiter.end(),
                             neighbour.end()),
                         thrust::make_permutation_iterator(
                             maxVsf.end(),
                             neighbour.end()),
                         thrust::make_permutation_iterator(
                             minVsf.end(),
                             neighbour.end()),
                         Cf.field().end(),
                         thrust::make_permutation_iterator(
                             C.field().end(),
                             neighbour.end()),
                         thrust::make_permutation_iterator(
                             g.field().end(),
                             neighbour.end()))),
                     cellLimitedGradcalcGradFunctor<gpucellLimitedGrad>(*this));

    // Boundary faces: limit the face-adjacent cell using the patch face
    // centres (all patches, coupled or not)
    forAll(bsf, patchi)
    {
       const labelgpuList& pOwner = mesh.boundary()[patchi].gpuFaceCells();
        const vectorgpuField& pCf = Cf.boundaryField()[patchi];


        thrust::for_each(thrust::make_zip_iterator(thrust::make_tuple(
                         thrust::make_permutation_iterator(
                             limiter.begin(),
                             pOwner.begin()),
                         thrust::make_permutation_iterator(
                             maxVsf.begin(),
                             pOwner.begin()),
                         thrust::make_permutation_iterator(
                             minVsf.begin(),
                             pOwner.begin()),
                         pCf.begin(),
                         thrust::make_permutation_iterator(
                             C.field().begin(),
                             pOwner.begin()),
                         thrust::make_permutation_iterator(
                             g.field().begin(),
                             pOwner.begin()))),
                     thrust::make_zip_iterator(thrust::make_tuple(
                         thrust::make_permutation_iterator(
                             limiter.end(),
                             pOwner.end()),
                         thrust::make_permutation_iterator(
                             maxVsf.end(),
                             pOwner.end()),
                         thrust::make_permutation_iterator(
                             minVsf.end(),
                             pOwner.end()),
                         pCf.end(),
                         thrust::make_permutation_iterator(
                             C.field().end(),
                             pOwner.end()),
                         thrust::make_permutation_iterator(
                             g.field().end(),
                             pOwner.end()))),
                     cellLimitedGradcalcGradFunctor<gpucellLimitedGrad>(*this));

    }

    // Optional diagnostic summary of the computed limiter
    if (fv::debug)
    {
        Info<< "gradient limiter for: " << vsf.name()
            << " max = " << gMax(limiter)
            << " min = " << gMin(limiter)
            << " average: " << gAverage(limiter) << endl;
    }

    // Scale the gradient by the limiter, then refresh its boundary values
    // and apply the gauss-gradient boundary correction
    // (see gpugaussGrad::correctBoundaryConditions)
    limitGradient(limiter, g);
    g.correctBoundaryConditions();
    gpugaussGrad<Type>::correctBoundaryConditions(vsf, g);

    return tGrad;
}


// ************************************************************************* //
