/*---------------------------------------------------------------------------*\
  =========                 |
  \\      /  F ield         | OpenFOAM: The Open Source CFD Toolbox
   \\    /   O peration     |
    \\  /    A nd           | www.openfoam.com
     \\/     M anipulation  |
-------------------------------------------------------------------------------
    Copyright (C) 2011-2017 OpenFOAM Foundation
    Copyright (C) 2019 OpenCFD Ltd.
-------------------------------------------------------------------------------
License
    This file is part of OpenFOAM.

    OpenFOAM is free software: you can redistribute it and/or modify it
    under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    OpenFOAM is distributed in the hope that it will be useful, but WITHOUT
    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
    for more details.

    You should have received a copy of the GNU General Public License
    along with OpenFOAM.  If not, see <http://www.gnu.org/licenses/>.

\*---------------------------------------------------------------------------*/

#include "gpufvMatrix.H"
#include "cyclicFvPatchgpuField.H"
#include "transformField.H"
#include "volgpuFields.H"

// * * * * * * * * * * * * * * * * Constructors  * * * * * * * * * * * * * * //

// Construct from the patch and the internal field.
//
// Forwards both arguments to coupledFvPatchgpuField<Type> and stores the
// patch, narrowed to cyclicgpuFvPatch with refCast, in cyclicPatch_.
// The constructor body itself does nothing further.
// NOTE(review): refCast presumably raises a fatal error when p is not a
// cyclicgpuFvPatch - confirm against its definition.
template<class Type>
Foam::cyclicFvPatchgpuField<Type>::cyclicFvPatchgpuField
(
    const gpufvPatch& p,
    const DimensionedgpuField<Type, gpuvolMesh>& iF
)
:
    coupledFvPatchgpuField<Type>(p, iF),
    cyclicPatch_(refCast<const cyclicgpuFvPatch>(p))
{}


// Construct from the patch, the internal field and a dictionary.
//
// The base class is always constructed with its value-required flag set to
// false. When the caller passes valueRequired == true the patch field is
// evaluated here instead, with blocking communications.
// A fatal IO error is raised when p is not a cyclicgpuFvPatch.
template<class Type>
Foam::cyclicFvPatchgpuField<Type>::cyclicFvPatchgpuField
(
    const gpufvPatch& p,
    const DimensionedgpuField<Type, gpuvolMesh>& iF,
    const dictionary& dict,
    const bool valueRequired
)
:
    // The base class never reads a value: 'false' is passed unconditionally
    coupledFvPatchgpuField<Type>(p, iF, dict, false),
    cyclicPatch_(refCast<const cyclicgpuFvPatch>(p, dict))
{
    const bool patchIsCyclic = isA<cyclicgpuFvPatch>(p);

    if (!patchIsCyclic)
    {
        FatalIOErrorInFunction(dict)
            << "    patch type '" << p.type()
            << "' not constraint type '" << typeName << "'"
            << "\n    for patch " << p.name()
            << " of field " << this->internalField().name()
            << " in file " << this->internalField().objectPath()
            << exit(FatalIOError);
    }

    // Nothing more to do unless the caller asked for patch values
    if (!valueRequired)
    {
        return;
    }

    this->evaluate(Pstream::commsTypes::blocking);
}


// Construct by mapping the given cyclicFvPatchgpuField onto a new patch.
//
// The source field, target patch, internal field and mapper are forwarded to
// coupledFvPatchgpuField<Type>; the target patch is narrowed to
// cyclicgpuFvPatch with refCast and stored in cyclicPatch_.
// A fatal error naming the offending patch type is raised when the patch of
// the constructed field is not a cyclicgpuFvPatch.
template<class Type>
Foam::cyclicFvPatchgpuField<Type>::cyclicFvPatchgpuField
(
    const cyclicFvPatchgpuField<Type>& ptf,
    const gpufvPatch& p,
    const DimensionedgpuField<Type, gpuvolMesh>& iF,
    const fvPatchgpuFieldMapper& mapper
)
:
    coupledFvPatchgpuField<Type>(ptf, p, iF, mapper),
    cyclicPatch_(refCast<const cyclicgpuFvPatch>(p))
{
    if (!isA<cyclicgpuFvPatch>(this->patch()))
    {
        // Same wording as the dictionary constructor: report the actual
        // patch type before the expected constraint type
        FatalErrorInFunction
            << "    patch type '" << p.type()
            << "' not constraint type '" << typeName << "'"
            << "\n    for patch " << p.name()
            << " of field " << this->internalField().name()
            << " in file " << this->internalField().objectPath()
            << exit(FatalError);
    }
}


// Copy construct.
//
// The cyclicLduInterfacegpuField base is default-constructed explicitly, the
// coupledFvPatchgpuField<Type> base is copy-constructed from ptf, and
// cyclicPatch_ is taken from ptf.cyclicPatch_.
template<class Type>
Foam::cyclicFvPatchgpuField<Type>::cyclicFvPatchgpuField
(
    const cyclicFvPatchgpuField<Type>& ptf
)
:
    cyclicLduInterfacegpuField(),
    coupledFvPatchgpuField<Type>(ptf),
    cyclicPatch_(ptf.cyclicPatch_)
{}


// Copy construct, setting a new internal field reference.
//
// ptf and iF are forwarded to coupledFvPatchgpuField<Type>; cyclicPatch_ is
// taken from ptf.cyclicPatch_.
template<class Type>
Foam::cyclicFvPatchgpuField<Type>::cyclicFvPatchgpuField
(
    const cyclicFvPatchgpuField<Type>& ptf,
    const DimensionedgpuField<Type, gpuvolMesh>& iF
)
:
    coupledFvPatchgpuField<Type>(ptf, iF),
    cyclicPatch_(ptf.cyclicPatch_)
{}


// * * * * * * * * * * * * * * * Member Functions  * * * * * * * * * * * * * //

// Return the internal-field values seen from the neighbouring cyclic patch.
//
// A new field of this->size() entries is filled by reading the primitive
// (internal) field at the face-cell labels of the neighbour patch, i.e.
// entry i is taken from internal[nbrCells[i]].
// When doTransform() is true each gathered value is additionally passed
// through transformBinaryFunctionSFFunctor built from the first entry of
// gpuForwardT(); otherwise the gathered values are copied unchanged.
// NOTE(review): assumes the neighbour face-cell list holds this->size()
// entries - confirm against cyclicgpuFvPatch.
template<class Type>
Foam::tmp<Foam::gpuField<Type>>
Foam::cyclicFvPatchgpuField<Type>::patchNeighbourField() const
{
    const gpuField<Type>& internalValues = this->primitiveField();
    const labelgpuList& nbrCells =
        cyclicPatch().cyclicPatch().neighbPatch().gpuFaceCells();

    tmp<gpuField<Type>> tresult(new gpuField<Type>(this->size()));
    gpuField<Type>& result = tresult.ref();

    // Indirect view of the internal field through the neighbour cell labels;
    // both branches below read exactly this range
    auto gatherBegin =
        thrust::make_permutation_iterator
        (
            internalValues.begin(),
            nbrCells.begin()
        );
    auto gatherEnd =
        thrust::make_permutation_iterator
        (
            internalValues.begin(),
            nbrCells.end()
        );

    if (!doTransform())
    {
        // No transformation: plain gather
        thrust::copy(gatherBegin, gatherEnd, result.begin());

        return tresult;
    }

    // Gather and transform in a single pass
    tensor forwardT = gpuForwardT().first();

    thrust::transform
    (
        gatherBegin,
        gatherEnd,
        result.begin(),
        transformBinaryFunctionSFFunctor<tensor,Type,Type>(forwardT)
    );

    return tresult;
}


// Return the patch field on the neighbouring cyclic patch.
//
// The primitive field is statically cast to the enclosing geometric field,
// whose boundary field is then indexed with the neighbour patch ID; the
// entry is narrowed to cyclicFvPatchgpuField<Type> with refCast.
// NOTE(review): the static_cast is unchecked - it relies on the primitive
// field actually being part of a GeometricgpuField.
template<class Type>
const Foam::cyclicFvPatchgpuField<Type>&
Foam::cyclicFvPatchgpuField<Type>::neighbourPatchField() const
{
    typedef GeometricgpuField<Type, fvPatchgpuField, gpuvolMesh> GeoFieldType;
    typedef cyclicFvPatchgpuField<Type> CyclicFieldType;

    const GeoFieldType& geoField =
        static_cast<const GeoFieldType&>(this->primitiveField());

    const label nbrPatchID = this->cyclicPatch().neighbPatchID();

    return refCast<const CyclicFieldType>
    (
        geoField.boundaryField()[nbrPatchID]
    );
}




// Update the result field with this interface's contribution (scalar form).
//
// psiInternal is sampled at the patch addressing of the neighbour patch,
// passed through transformCoupleField for component cmpt, and handed with
// the coefficients to coupledFvPatchgpuField<Type>::updateInterfaceMatrix
// using the negated 'add' flag.
// patchId and commsType are not used by this implementation.
template<class Type>
void Foam::cyclicFvPatchgpuField<Type>::updateInterfaceMatrix
(
    scalargpuField& result,
    const bool add,
    const gpulduAddressing& lduAddr,
    const label patchId,
    const scalargpuField& psiInternal,
    const scalargpuField& coeffs,
    const direction cmpt,
    const Pstream::commsTypes commsType
) const
{
    const label nbrPatchID = this->cyclicPatch().neighbPatchID();
    const labelgpuList& nbrCells = lduAddr.patchAddr(nbrPatchID);

    // Neighbour-side values of psi
    // (presumably gathered by the field/addressing constructor - confirm)
    scalargpuField nbrPsi(psiInternal, nbrCells);

    // Apply the coupling transformation for this component
    transformCoupleField(nbrPsi, cmpt);

    coupledFvPatchgpuField<Type>::updateInterfaceMatrix
    (
        result,
        coeffs,
        nbrPsi,
        !add
    );
}

// Update the result field with this interface's contribution (Type form).
//
// psiInternal is sampled at the patch addressing of the neighbour patch,
// passed through transformCoupleField, and then combined with the
// coefficients into result by addToInternalField using the negated 'add'
// flag. The comms type argument is not used by this implementation.
template<class Type>
void Foam::cyclicFvPatchgpuField<Type>::updateInterfaceMatrix
(
    gpuField<Type>& result,
    const bool add,
    const gpulduAddressing& lduAddr,
    const label patchId,
    const gpuField<Type>& psiInternal,
    const scalargpuField& coeffs,
    const Pstream::commsTypes
) const
{
    const label nbrPatchID = this->cyclicPatch().neighbPatchID();
    const labelgpuList& nbrCells = lduAddr.patchAddr(nbrPatchID);

    // Neighbour-side values of psi
    // (presumably gathered by the field/addressing constructor - confirm)
    gpuField<Type> nbrPsi(psiInternal, nbrCells);

    // Apply the coupling transformation
    transformCoupleField(nbrPsi);

    // Scale by the coefficients and accumulate into the result
    this->addToInternalField
    (
        result,
        !add,
        patchId,
        lduAddr,
        coeffs,
        nbrPsi
    );
}

// Write the patch field to the given stream.
//
// Delegates entirely to fvPatchgpuField<Type>::write; no cyclic-specific
// entries are written by this function.
template<class Type>
void Foam::cyclicFvPatchgpuField<Type>::write(Ostream& os) const
{
    fvPatchgpuField<Type>::write(os);
}


// Fold this patch's coupling coefficients for component cmpt into the
// coefficients of the assembled matrix.
//
// Only the owner side of the cyclic pair does any work. For every patch face
// listed in the assembly face map for (mat, patch index):
//   - upper[face]       += -boundaryCoeff
//   - diag[upper cell]  -= -internalCoeff
//   - diag[lower cell]  -= -boundaryCoeff
//   - lower[face]       += -internalCoeff   (only while matrix.asymmetric())
//
// If flux is required for this field, the internal and boundary coefficients
// of both this patch and its neighbour patch are then replaced by the cmpt
// component coefficients multiplied by pTraits<Type>::one.
//
// Parameters:
//   matrix  matrix whose coefficients are modified
//   mat     index used with the assembly maps and matrix.psi()
//   cmpt    component of the coefficients to use
//
// NOTE(review): the loop runs on the host with element-wise get()/set()
// calls on GPU fields; each call presumably implies a host/device transfer.
// Consider a device-side kernel if this shows up in profiles.
template<class Type>
void Foam::cyclicFvPatchgpuField<Type>::manipulateMatrix
(
    gpufvMatrix<Type>& matrix,
    const label mat,
    const direction cmpt
)
{
    if (this->cyclicPatch().owner())
    {
        const label index = this->patch().index();

        const label globalPatchID =
            matrix.lduMeshAssembly().patchLocalToGlobalMap()[mat][index];

        const gpuField<scalar> intCoeffsCmpt
        (
            matrix.internalCoeffs()[globalPatchID].component(cmpt)
        );

        const gpuField<scalar> boundCoeffsCmpt
        (
            matrix.boundaryCoeffs()[globalPatchID].component(cmpt)
        );

        const labelgpuList& u = matrix.lduAddr().upperAddr();
        const labelgpuList& l = matrix.lduAddr().lowerAddr();

        const labelList& faceMap =
            matrix.lduMeshAssembly().faceBoundMap()[mat][index];

        forAll (faceMap, faceI)
        {
            const label globalFaceI = faceMap[faceI];

            // Read the addressing once per face instead of once per use
            const label upperCelli = u.get(globalFaceI);
            const label lowerCelli = l.get(globalFaceI);

            const scalar boundCorr = -boundCoeffsCmpt.get(faceI);
            const scalar intCorr = -intCoeffsCmpt.get(faceI);

            const scalar vu = matrix.gpuUpper().get(globalFaceI);
            matrix.gpuUpper().set(globalFaceI, vu + boundCorr);

            // The two diagonal updates are done as sequential
            // read-modify-write steps: the second read must see the first
            // write, otherwise the intCorr contribution is lost whenever
            // both addresses refer to the same cell
            const scalar vdu = matrix.gpuDiag().get(upperCelli);
            matrix.gpuDiag().set(upperCelli, vdu - intCorr);

            const scalar vdl = matrix.gpuDiag().get(lowerCelli);
            matrix.gpuDiag().set(lowerCelli, vdl - boundCorr);

            // Deliberately queried per face, after the coefficients above
            // have been accessed, exactly as before (not hoisted)
            if (matrix.asymmetric())
            {
                const scalar vl = matrix.gpuLower().get(globalFaceI);
                matrix.gpuLower().set(globalFaceI, vl + intCorr);
            }
        }

        if
        (
            matrix.psi(mat).mesh().hostmesh().fluxRequired
            (
                this->internalField().name()
            )
        )
        {
            // This (owner) patch
            matrix.internalCoeffs().set
            (
                globalPatchID, intCoeffsCmpt*pTraits<Type>::one
            );
            matrix.boundaryCoeffs().set
            (
                globalPatchID, boundCoeffsCmpt*pTraits<Type>::one
            );

            // Neighbour patch receives the same coefficients
            const label nbrPatchID = this->cyclicPatch().neighbPatchID();

            const label nbrGlobalPatchID =
                matrix.lduMeshAssembly().patchLocalToGlobalMap()
                    [mat][nbrPatchID];

            matrix.internalCoeffs().set
            (
                nbrGlobalPatchID, intCoeffsCmpt*pTraits<Type>::one
            );
            matrix.boundaryCoeffs().set
            (
                nbrGlobalPatchID, boundCoeffsCmpt*pTraits<Type>::one
            );
        }
    }
}

// ************************************************************************* //
