/*---------------------------------------------------------------------------*\
  =========                 |
  \\      /  F ield         | OpenFOAM: The Open Source CFD Toolbox
   \\    /   O peration     |
    \\  /    A nd           | www.openfoam.com
     \\/     M anipulation  |
-------------------------------------------------------------------------------
    Copyright (C) 2011-2016 OpenFOAM Foundation
    Copyright (C) 2019-2021 OpenCFD Ltd.
-------------------------------------------------------------------------------
License
    This file is part of OpenFOAM.

    OpenFOAM is free software: you can redistribute it and/or modify it
    under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    OpenFOAM is distributed in the hope that it will be useful, but WITHOUT
    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
    for more details.

    You should have received a copy of the GNU General Public License
    along with OpenFOAM.  If not, see <http://www.gnu.org/licenses/>.

\*---------------------------------------------------------------------------*/

#include "gpusurfaceInterpolationScheme.H"
#include "volgpuFields.H"
#include "surfacegpuFields.H"
#include "geometricOneField.H"
#include "coupledFvPatchgpuField.H"

// * * * * * * * * * * * * * * * * * Selectors * * * * * * * * * * * * * * * //

template<class Type>
Foam::tmp<Foam::gpusurfaceInterpolationScheme<Type>>
Foam::gpusurfaceInterpolationScheme<Type>::New
(
    const gpufvMesh& mesh,
    Istream& schemeData
)
{
    // Run-time selector: reads the scheme name from schemeData and builds
    // the scheme through the Mesh constructor table. The remainder of
    // schemeData is forwarded to the selected constructor.

    // An exhausted stream means no scheme name was supplied at all
    if (schemeData.eof())
    {
        FatalIOErrorInFunction(schemeData)
            << "Discretisation scheme not specified\n\n"
            << "Valid schemes:\n"
            << MeshConstructorTablePtr_->sortedToc()
            << exit(FatalIOError);
    }

    const word schemeName(schemeData);

    const bool reportScheme =
        gpusurfaceInterpolation::debug
     || gpusurfaceInterpolationScheme<Type>::debug;

    if (reportScheme)
    {
        InfoInFunction << "Discretisation scheme = " << schemeName << endl;
    }

    auto* constructor = MeshConstructorTable(schemeName);

    if (constructor == nullptr)
    {
        FatalIOErrorInLookup
        (
            schemeData,
            "discretisation",
            schemeName,
            *MeshConstructorTablePtr_
        ) << exit(FatalIOError);
    }

    return constructor(mesh, schemeData);
}


template<class Type>
Foam::tmp<Foam::gpusurfaceInterpolationScheme<Type>>
Foam::gpusurfaceInterpolationScheme<Type>::New
(
    const gpufvMesh& mesh,
    const surfaceScalargpuField& faceFlux,
    Istream& schemeData
)
{
    // Run-time selector for flux-dependent schemes: reads the scheme name
    // from schemeData and builds the scheme through the MeshFlux constructor
    // table. faceFlux and the remainder of schemeData are forwarded to the
    // selected constructor.

    if (schemeData.eof())
    {
        // List the MeshFlux table. That is the table searched below, so it
        // holds the schemes that are actually selectable from this function.
        FatalIOErrorInFunction(schemeData)
            << "Discretisation scheme not specified"
            << endl << endl
            << "Valid schemes are :" << endl
            << MeshFluxConstructorTablePtr_->sortedToc()
            << exit(FatalIOError);
    }

    const word schemeName(schemeData);

    if (gpusurfaceInterpolation::debug || gpusurfaceInterpolationScheme<Type>::debug)
    {
        InfoInFunction << "Discretisation scheme = " << schemeName << endl;
    }

    auto* ctorPtr = MeshFluxConstructorTable(schemeName);

    if (!ctorPtr)
    {
        FatalIOErrorInLookup
        (
            schemeData,
            "discretisation",
            schemeName,
            *MeshFluxConstructorTablePtr_
        ) << exit(FatalIOError);
    }

    return ctorPtr(mesh, faceFlux, schemeData);
}


// * * * * * * * * * * * * * * * Member Functions  * * * * * * * * * * * * * //
namespace Foam
{
// Per-face functor for the explicit two-weight interpolation used by
// interpolate(vf, tlambdas, tys):
//     result = lambda*vOwner + y*vNeighbour
// Tuple layout: (lambda, owner cell value, y, neighbour cell value).
// operator() is const: the functor holds no state, so it must be callable
// through a const object/reference by the thrust backend.
template<class Type>
struct surfaceInterpolationSchemeInterpolateYFunctor
{
    __host__ __device__
    Type operator()(const thrust::tuple<scalar,Type,scalar,Type>& t) const
    {
        return thrust::get<0>(t)*thrust::get<1>(t) + thrust::get<2>(t)*thrust::get<3>(t);
    }
};
}

template<class Type>
Foam::tmp<Foam::GeometricgpuField<Type, Foam::fvsPatchgpuField, Foam::gpusurfaceMesh>>
Foam::gpusurfaceInterpolationScheme<Type>::interpolate
(
    const GeometricgpuField<Type, fvPatchgpuField, gpuvolMesh>& vf,
    const tmp<surfaceScalargpuField>& tlambdas,
    const tmp<surfaceScalargpuField>& tys
)
{
    // Cell-to-face interpolation with independent owner (lambda) and
    // neighbour (y) weights. No explicit correction is applied. Both weight
    // temporaries are cleared before returning.
    if (gpusurfaceInterpolation::debug)
    {
        InfoInFunction
            << "Interpolating "
            << vf.type() << " "
            << vf.name()
            << " from cells to faces without explicit correction"
            << endl;
    }

    typedef GeometricgpuField<Type, fvsPatchgpuField, gpusurfaceMesh>
        SurfFieldType;

    const surfaceScalargpuField& ownWeights = tlambdas();
    const surfaceScalargpuField& neiWeights = tys();

    const gpufvMesh& mesh = vf.mesh();
    const labelgpuList& own = mesh.owner();
    const labelgpuList& nei = mesh.neighbour();

    tmp<SurfFieldType> tsf
    (
        new SurfFieldType
        (
            IOobject
            (
                "interpolate("+vf.name()+')',
                vf.instance(),
                vf.db()
            ),
            mesh,
            vf.dimensions()
        )
    );
    SurfFieldType& sf = tsf.ref();

    // Internal faces, one thrust pass:
    //     sf[f] = lambda[f]*v[own[f]] + y[f]*v[nei[f]]
    {
        const gpuField<Type>& cellValues = vf;
        const scalargpuField& lambda = ownWeights;
        const scalargpuField& y = neiWeights;
        gpuField<Type>& faceValues = sf.primitiveFieldRef();

        auto first = thrust::make_zip_iterator
        (
            thrust::make_tuple
            (
                lambda.begin(),
                thrust::make_permutation_iterator(cellValues.begin(), own.begin()),
                y.begin(),
                thrust::make_permutation_iterator(cellValues.begin(), nei.begin())
            )
        );

        thrust::transform
        (
            first,
            first + own.size(),
            faceValues.begin(),
            surfaceInterpolationSchemeInterpolateYFunctor<Type>()
        );
    }

    // Boundary faces: coupled patches blend both sides with the supplied
    // weights. All other patches take the vol-field boundary values.
    typename SurfFieldType::Boundary& sfbf = sf.boundaryFieldRef();

    forAll(ownWeights.boundaryField(), patchi)
    {
        const fvsPatchScalargpuField& pLambda = ownWeights.boundaryField()[patchi];
        const fvsPatchScalargpuField& pY = neiWeights.boundaryField()[patchi];
        const auto& pvf = vf.boundaryField()[patchi];

        if (!pvf.coupled())
        {
            sfbf[patchi] = pvf;
            continue;
        }

        sfbf[patchi] =
            pLambda*pvf.patchInternalField()
          + pY*pvf.patchNeighbourField();
    }

    tlambdas.clear();
    tys.clear();

    return tsf;
}


namespace Foam
{
// Per-face functor for the area-dotted linear interpolation used by
// dotInterpolate(Sf, vf, tlambdas):
//     result = Sf & (lambda*(vOwner - vNeighbour) + vNeighbour)
// Tuple layout: (face area vector, lambda, owner cell value,
// neighbour cell value).
// operator() is const: the functor holds no state, so it must be callable
// through a const object/reference by the thrust backend.
template<class Type,class RetType>
struct surfaceInterpolationSchemeInterpolateFunctor{
    __host__ __device__
    RetType operator()(const thrust::tuple<vector,scalar,Type,Type>& t) const {
        return thrust::get<0>(t) & (thrust::get<1>(t)*(thrust::get<2>(t) - thrust::get<3>(t)) + thrust::get<3>(t));
    }
};
}


template<class Type>
template<class SFType>
Foam::tmp
<
    Foam::GeometricgpuField
    <
        typename Foam::innerProduct<typename SFType::value_type, Type>::type,
        Foam::fvsPatchgpuField,
        Foam::gpusurfaceMesh
    >
>
Foam::gpusurfaceInterpolationScheme<Type>::dotInterpolate
(
    const SFType& Sf,
    const GeometricgpuField<Type, fvPatchgpuField, gpuvolMesh>& vf,
    const tmp<surfaceScalargpuField>& tlambdas
)
{
    // Linear cell-to-face interpolation with weights tlambdas, dotted with
    // the face field Sf. No explicit correction is applied; tlambdas is
    // cleared before returning.
    // NOTE(review): the internal-face kernel reads Sf through a
    // gpuField<vector>, so as written this assumes SFType::value_type is
    // vector.
    if (gpusurfaceInterpolation::debug)
    {
        InfoInFunction
            << "Interpolating "
            << vf.type() << " "
            << vf.name()
            << " from cells to faces without explicit correction"
            << endl;
    }

    typedef typename Foam::innerProduct<typename SFType::value_type, Type>::type
        RetType;
    typedef GeometricgpuField<RetType, fvsPatchgpuField, gpusurfaceMesh>
        RetFieldType;

    const surfaceScalargpuField& weights = tlambdas();

    const gpufvMesh& mesh = vf.mesh();
    const labelgpuList& own = mesh.owner();
    const labelgpuList& nei = mesh.neighbour();

    tmp<RetFieldType> tsf
    (
        new RetFieldType
        (
            IOobject
            (
                "interpolate("+vf.name()+')',
                vf.instance(),
                vf.db()
            ),
            mesh,
            Sf.dimensions()*vf.dimensions()
        )
    );
    RetFieldType& sf = tsf.ref();

    // Internal faces, one thrust pass:
    //     sf[f] = Sf[f] & (w[f]*(v[own[f]] - v[nei[f]]) + v[nei[f]])
    {
        const gpuField<Type>& cellValues = vf;
        const scalargpuField& w = weights;
        gpuField<RetType>& faceValues = sf.primitiveFieldRef();

        const typename SFType::Internal& SfInternal = Sf();
        const gpuField<vector>& areas = SfInternal;

        auto first = thrust::make_zip_iterator
        (
            thrust::make_tuple
            (
                areas.begin(),
                w.begin(),
                thrust::make_permutation_iterator(cellValues.begin(), own.begin()),
                thrust::make_permutation_iterator(cellValues.begin(), nei.begin())
            )
        );

        thrust::transform
        (
            first,
            first + own.size(),
            faceValues.begin(),
            surfaceInterpolationSchemeInterpolateFunctor<Type, RetType>()
        );
    }

    // Boundary faces: coupled patches blend both sides with w and (1 - w).
    // All other patches dot Sf with the vol-field boundary values.
    typename RetFieldType::Boundary& sfbf = sf.boundaryFieldRef();

    forAll(weights.boundaryField(), patchi)
    {
        const fvsPatchScalargpuField& pw = weights.boundaryField()[patchi];
        const typename SFType::Patch& pSf = Sf.boundaryField()[patchi];
        const auto& pvf = vf.boundaryField()[patchi];

        if (!pvf.coupled())
        {
            sfbf[patchi] = pSf & pvf;
            continue;
        }

        sfbf[patchi] =
            pSf
          & (
                pw*pvf.patchInternalField()
              + (1.0 - pw)*pvf.patchNeighbourField()
            );
    }

    tlambdas.clear();

    // Orientation is deliberately not copied from Sf here. The public
    // dotInterpolate(Sf, vf) entry point sets it on the returned field.

    return tsf;
}

namespace Foam
{
// Per-face functor for the plain linear interpolation used by
// dotInterpolate(vf, tlambdas):
//     result = lambda*(vOwner - vNeighbour) + vNeighbour
// Tuple layout: (lambda, owner cell value, neighbour cell value).
// operator() is const: the functor holds no state, so it must be callable
// through a const object/reference by the thrust backend.
template<class Type>
struct surfaceInterpolationSchemeInterpolateOldFunctor{
    __host__ __device__
    Type operator()(const thrust::tuple<scalar,Type,Type>& t) const {
        return thrust::get<0>(t)*(thrust::get<1>(t) - thrust::get<2>(t)) + thrust::get<2>(t);
    }
};
}


template<class Type>
Foam::tmp<Foam::GeometricgpuField<Type, Foam::fvsPatchgpuField, Foam::gpusurfaceMesh>>
Foam::gpusurfaceInterpolationScheme<Type>::dotInterpolate
(
    const GeometricgpuField<Type, fvPatchgpuField, gpuvolMesh>& vf,
    const tmp<surfaceScalargpuField>& tlambdas
)
{
    // Linear cell-to-face interpolation of vf with weights tlambdas
    // (owner weight lambda, neighbour weight 1 - lambda). No face-area dot
    // product and no explicit correction; tlambdas is cleared before
    // returning.

    // Use InfoInFunction like the other members. The previous hand-written
    // message named surfaceInterpolationScheme<Type>::interpolate, which is
    // not this function.
    if (gpusurfaceInterpolation::debug)
    {
        InfoInFunction
            << "Interpolating "
            << vf.type() << " "
            << vf.name()
            << " from cells to faces without explicit correction"
            << endl;
    }

    const surfaceScalargpuField& lambdas = tlambdas();

    const gpuField<Type>& vfi = vf;
    const scalargpuField& lambda = lambdas;

    const gpufvMesh& mesh = vf.mesh();
    const labelgpuList& P = mesh.owner();
    const labelgpuList& N = mesh.neighbour();

    tmp<GeometricgpuField<Type, fvsPatchgpuField, gpusurfaceMesh>> tsf
    (
        new GeometricgpuField<Type, fvsPatchgpuField, gpusurfaceMesh>
        (
            IOobject
            (
                "interpolate("+vf.name()+')',
                vf.instance(),
                vf.db()
            ),
            mesh,
            vf.dimensions()
        )
    );
    GeometricgpuField<Type, fvsPatchgpuField, gpusurfaceMesh>& sf = tsf.ref();

    gpuField<Type>& sfi = sf.primitiveFieldRef();

    // Internal faces: sf[f] = lambda[f]*(v[P[f]] - v[N[f]]) + v[N[f]]
    thrust::transform
    (
        thrust::make_zip_iterator(thrust::make_tuple
        (
            lambda.begin(),
            thrust::make_permutation_iterator(vfi.begin(),P.begin()),
            thrust::make_permutation_iterator(vfi.begin(),N.begin())
        )),
        thrust::make_zip_iterator(thrust::make_tuple
        (
            lambda.begin()+P.size(),
            thrust::make_permutation_iterator(vfi.begin(),P.end()),
            thrust::make_permutation_iterator(vfi.begin(),N.end())
        )),
        sfi.begin(),
        surfaceInterpolationSchemeInterpolateOldFunctor<Type>()
    );

    // Interpolate across coupled patches using given lambdas. Fetch the
    // boundary reference once instead of calling boundaryFieldRef() on
    // every patch iteration.
    typename GeometricgpuField<Type, fvsPatchgpuField, gpusurfaceMesh>::
        Boundary& sfbf = sf.boundaryFieldRef();

    forAll(lambdas.boundaryField(), pi)
    {
        const fvsPatchScalargpuField& pLambda = lambdas.boundaryField()[pi];

        if (vf.boundaryField()[pi].coupled())
        {
            sfbf[pi] =
                pLambda*vf.boundaryField()[pi].patchInternalField()
              + (1.0 - pLambda)*vf.boundaryField()[pi].patchNeighbourField();
        }
        else
        {
            sfbf[pi] = vf.boundaryField()[pi];
        }
    }

    tlambdas.clear();

    return tsf;
}

template<class Type>
Foam::tmp<Foam::GeometricgpuField<Type, Foam::fvsPatchgpuField, Foam::gpusurfaceMesh>>
Foam::gpusurfaceInterpolationScheme<Type>::interpolate
(
    const GeometricgpuField<Type, fvPatchgpuField, gpuvolMesh>& vf,
    const tmp<surfaceScalargpuField>& tlambdas
)
{
    // Plain weighted interpolation. It delegates to the two-argument
    // dotInterpolate overload, which applies the weights without any
    // face-area dot product. That overload replaces a
    // dotInterpolate(geometricOneField(), vf, tlambdas) formulation.
    tmp<GeometricgpuField<Type, fvsPatchgpuField, gpusurfaceMesh>> tsf =
        dotInterpolate(vf, tlambdas);

    return tsf;
}


template<class Type>
Foam::tmp
<
    Foam::GeometricgpuField
    <
        typename Foam::innerProduct<Foam::vector, Type>::type,
        Foam::fvsPatchgpuField,
        Foam::gpusurfaceMesh
    >
>
Foam::gpusurfaceInterpolationScheme<Type>::dotInterpolate
(
    const surfaceVectorgpuField& Sf,
    const GeometricgpuField<Type, fvPatchgpuField, gpuvolMesh>& vf
) const
{
    // Sf & interpolate(vf) using this scheme's weights. The orientation flag
    // is taken from Sf, and the explicit correction is added when the scheme
    // is corrected().
    if (gpusurfaceInterpolation::debug)
    {
        InfoInFunction
            << "Interpolating "
            << vf.type() << " "
            << vf.name()
            << " from cells to faces"
            << endl;
    }

    typedef GeometricgpuField
    <
        typename Foam::innerProduct<Foam::vector, Type>::type,
        fvsPatchgpuField,
        gpusurfaceMesh
    > SfDotFieldType;

    // Weighted interpolation fused with the face-area dot product
    tmp<SfDotFieldType> tsf = dotInterpolate(Sf, vf, weights(vf));

    SfDotFieldType& sf = tsf.ref();
    sf.oriented() = Sf.oriented();

    if (corrected())
    {
        sf += Sf & correction(vf);
    }

    return tsf;
}


template<class Type>
Foam::tmp
<
    Foam::GeometricgpuField
    <
        typename Foam::innerProduct<Foam::vector, Type>::type,
        Foam::fvsPatchgpuField,
        Foam::gpusurfaceMesh
    >
>
Foam::gpusurfaceInterpolationScheme<Type>::dotInterpolate
(
    const surfaceVectorgpuField& Sf,
    const tmp<GeometricgpuField<Type, fvPatchgpuField, gpuvolMesh>>& tvf
) const
{
    // tmp-argument convenience overload: evaluates Sf & interpolate(tvf())
    // and then releases the temporary input field.
    typedef GeometricgpuField
    <
        typename Foam::innerProduct<Foam::vector, Type>::type,
        fvsPatchgpuField,
        gpusurfaceMesh
    > SfDotFieldType;

    tmp<SfDotFieldType> tresult = dotInterpolate(Sf, tvf());

    tvf.clear();

    return tresult;
}


template<class Type>
Foam::tmp<Foam::GeometricgpuField<Type, Foam::fvsPatchgpuField, Foam::gpusurfaceMesh>>
Foam::gpusurfaceInterpolationScheme<Type>::interpolate
(
    const GeometricgpuField<Type, fvPatchgpuField, gpuvolMesh>& vf
) const
{
    // Interpolates vf to faces with this scheme's weights and adds the
    // explicit correction when the scheme is corrected().
    if (gpusurfaceInterpolation::debug)
    {
        InfoInFunction
            << "Interpolating "
            << vf.type() << " "
            << vf.name()
            << " from cells to faces"
            << endl;
    }

    typedef GeometricgpuField<Type, fvsPatchgpuField, gpusurfaceMesh>
        SurfFieldType;

    tmp<SurfFieldType> tresult = interpolate(vf, weights(vf));

    if (corrected())
    {
        tresult.ref() += correction(vf);
    }

    return tresult;
}


template<class Type>
Foam::tmp<Foam::GeometricgpuField<Type, Foam::fvsPatchgpuField, Foam::gpusurfaceMesh>>
Foam::gpusurfaceInterpolationScheme<Type>::interpolate
(
    const tmp<GeometricgpuField<Type, fvPatchgpuField, gpuvolMesh>>& tvf
) const
{
    // tmp-argument convenience overload: interpolates the wrapped field and
    // then releases the temporary input.
    typedef GeometricgpuField<Type, fvsPatchgpuField, gpusurfaceMesh>
        SurfFieldType;

    tmp<SurfFieldType> tresult = interpolate(tvf());

    tvf.clear();

    return tresult;
}


// ************************************************************************* //
