/*---------------------------------------------------------------------------*\
  =========                 |
  \\      /  F ield         | OpenFOAM: The Open Source CFD Toolbox
   \\    /   O peration     |
    \\  /    A nd           | www.openfoam.com
     \\/     M anipulation  |
-------------------------------------------------------------------------------
    Copyright (C) 2013-2017 OpenFOAM Foundation
    Copyright (C) 2019 OpenCFD Ltd.
-------------------------------------------------------------------------------
License
    This file is part of OpenFOAM.

    OpenFOAM is free software: you can redistribute it and/or modify it
    under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    OpenFOAM is distributed in the hope that it will be useful, but WITHOUT
    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
    for more details.

    You should have received a copy of the GNU General Public License
    along with OpenFOAM.  If not, see <http://www.gnu.org/licenses/>.

\*---------------------------------------------------------------------------*/

#include "gpufvMatrix.H"
#include "cyclicACMIFvPatchgpuField.H"
#include "transformField.H"
#include "gpulduAddressingFunctors.H"

// * * * * * * * * * * * * * * * * Constructors  * * * * * * * * * * * * * * //

// Construct from patch and internal field.
// The supplied patch is down-cast with refCast to cyclicACMIgpuFvPatch and
// the resulting reference is stored in cyclicACMIPatch_.
template<class Type>
Foam::cyclicACMIFvPatchgpuField<Type>::cyclicACMIFvPatchgpuField
(
    const gpufvPatch& p,
    const DimensionedgpuField<Type, gpuvolMesh>& iF
)
:
    cyclicACMILduInterfacegpuField(),
    coupledFvPatchgpuField<Type>(p, iF),
    cyclicACMIPatch_(refCast<const cyclicACMIgpuFvPatch>(p))
{}


// Construct from patch, internal field and dictionary.
// The base class is told whether a "value" entry is present. When no "value"
// entry exists and the patch is coupled, the field is evaluated here, which
// requires the non-overlap patch field to have been constructed already.
template<class Type>
Foam::cyclicACMIFvPatchgpuField<Type>::cyclicACMIFvPatchgpuField
(
    const gpufvPatch& p,
    const DimensionedgpuField<Type, gpuvolMesh>& iF,
    const dictionary& dict
)
:
    cyclicACMILduInterfacegpuField(),
    coupledFvPatchgpuField<Type>(p, iF, dict, dict.found("value")),
    cyclicACMIPatch_(refCast<const cyclicACMIgpuFvPatch>(p, dict))
{
    typedef GeometricgpuField<Type, fvPatchgpuField, gpuvolMesh> GeoFieldType;

    // The patch has to be of the matching constraint type
    const bool isACMIPatch = isA<cyclicACMIgpuFvPatch>(p);

    if (!isACMIPatch)
    {
        FatalIOErrorInFunction(dict)
            << "    patch type '" << p.type()
            << "' not constraint type '" << typeName << "'"
            << "\n    for patch " << p.name()
            << " of field " << this->internalField().name()
            << " in file " << this->internalField().objectPath()
            << exit(FatalIOError);
    }

    // A supplied "value" entry, or an uncoupled patch, needs no evaluation
    if (dict.found("value") || !this->coupled())
    {
        return;
    }

    // Evaluation relies on the non-overlap patch field, so that patch must
    // appear earlier in the file and already be set in the boundary field
    const GeoFieldType& geoFld =
        static_cast<const GeoFieldType&>(this->primitiveField());

    const bool nonOverlapIsSet =
        geoFld.boundaryField().set(cyclicACMIPatch_.nonOverlapPatchID());

    if (!nonOverlapIsSet)
    {
        FatalIOErrorInFunction(dict)
            << "    patch " << p.name()
            << " of field " << this->internalField().name()
            << " refers to non-overlap patch "
            << cyclicACMIPatch_.cyclicACMIPatch().nonOverlapPatchName()
            << " which is not constructed yet." << nl
            << "    Either supply an initial value or change the ordering"
            << " in the file"
            << exit(FatalIOError);
    }

    this->evaluate(Pstream::commsTypes::blocking);
}


// Construct by mapping the given patch field onto a new patch.
// Raises a fatal error if the patch this field ends up on is not a
// cyclicACMIgpuFvPatch.
template<class Type>
Foam::cyclicACMIFvPatchgpuField<Type>::cyclicACMIFvPatchgpuField
(
    const cyclicACMIFvPatchgpuField<Type>& ptf,
    const gpufvPatch& p,
    const DimensionedgpuField<Type, gpuvolMesh>& iF,
    const fvPatchgpuFieldMapper& mapper
)
:
    cyclicACMILduInterfacegpuField(),
    coupledFvPatchgpuField<Type>(ptf, p, iF, mapper),
    cyclicACMIPatch_(refCast<const cyclicACMIgpuFvPatch>(p))
{
    if (!isA<cyclicACMIgpuFvPatch>(this->patch()))
    {
        // Report the offending patch type first, using the same wording as
        // the dictionary constructor, so the message reads as a sentence
        FatalErrorInFunction
            << "    patch type '" << p.type()
            << "' not constraint type '" << typeName << "'"
            << "\n    for patch " << p.name()
            << " of field " << this->internalField().name()
            << " in file " << this->internalField().objectPath()
            << exit(FatalError);
    }
}



// Copy construct.
// The new field refers to the same cyclicACMIgpuFvPatch as the source field.
template<class Type>
Foam::cyclicACMIFvPatchgpuField<Type>::cyclicACMIFvPatchgpuField
(
    const cyclicACMIFvPatchgpuField<Type>& ptf
)
:
    cyclicACMILduInterfacegpuField(),
    coupledFvPatchgpuField<Type>(ptf),
    cyclicACMIPatch_(ptf.cyclicACMIPatch_)
{}


// Copy construct, setting a new internal field reference.
// The new field refers to the same cyclicACMIgpuFvPatch as the source field.
template<class Type>
Foam::cyclicACMIFvPatchgpuField<Type>::cyclicACMIFvPatchgpuField
(
    const cyclicACMIFvPatchgpuField<Type>& ptf,
    const DimensionedgpuField<Type, gpuvolMesh>& iF
)
:
    cyclicACMILduInterfacegpuField(),
    coupledFvPatchgpuField<Type>(ptf, iF),
    cyclicACMIPatch_(ptf.cyclicACMIPatch_)
{}


// * * * * * * * * * * * * * * * Member Functions  * * * * * * * * * * * * * //

// Return the coupled state reported by the underlying ACMI fv patch
template<class Type>
bool Foam::cyclicACMIFvPatchgpuField<Type>::coupled() const
{
    const bool patchIsCoupled = cyclicACMIPatch_.coupled();

    return patchIsCoupled;
}


// Return the neighbour-side field: internal values gathered at the
// neighbour patch face cells, interpolated by the ACMI patch and, when
// doTransform() is true, transformed with gpuForwardT().
template<class Type>
Foam::tmp<Foam::gpuField<Type>>
Foam::cyclicACMIFvPatchgpuField<Type>::patchNeighbourField() const
{
    const gpuField<Type>& internalValues = this->primitiveField();

    // The neighbour is obtained from the fv patch's neighbPatch() rather
    // than going through the polyPatch
    const cyclicACMIgpuFvPatch& nbrPatch = cyclicACMIPatch_.neighbPatch();
    const labelgpuList& nbrCells = nbrPatch.gpuFaceCells();

    tmp<gpuField<Type>> tnbrField
    (
        cyclicACMIPatch_.interpolate
        (
            gpuField<Type>(internalValues, nbrCells)
        )
    );

    if (doTransform())
    {
        tnbrField.ref() = transform(gpuForwardT(), tnbrField());
    }

    return tnbrField;
}


// Return the patch field on the neighbour patch, looked up by
// neighbPatchID() in the boundary field and cast to this class type
template<class Type>
const Foam::cyclicACMIFvPatchgpuField<Type>&
Foam::cyclicACMIFvPatchgpuField<Type>::neighbourPatchField() const
{
    typedef GeometricgpuField<Type, fvPatchgpuField, gpuvolMesh> GeoFieldType;

    // The internal field is treated as the complete geometric field
    const GeoFieldType& geoFld =
        static_cast<const GeoFieldType&>(this->primitiveField());

    const fvPatchgpuField<Type>& nbrFld =
        geoFld.boundaryField()[cyclicACMIPatch_.neighbPatchID()];

    return refCast<const cyclicACMIFvPatchgpuField<Type>>(nbrFld);
}


// Return the patch field on the associated non-overlap patch, looked up by
// nonOverlapPatchID() in the boundary field
template<class Type>
const Foam::fvPatchgpuField<Type>&
Foam::cyclicACMIFvPatchgpuField<Type>::nonOverlapPatchField() const
{
    typedef GeometricgpuField<Type, fvPatchgpuField, gpuvolMesh> GeoFieldType;

    // The internal field is treated as the complete geometric field
    const GeoFieldType& geoFld =
        static_cast<const GeoFieldType&>(this->primitiveField());

    // WIP: Needs to re-direct nonOverlapPatchID to new patchId for assembly?
    const label nonOverlapId = cyclicACMIPatch_.nonOverlapPatchID();

    return geoFld.boundaryField()[nonOverlapId];
}


// Update result with the coupled contribution of this interface for a
// scalar (component) solve. Only the coupled part is applied here.
template<class Type>
void Foam::cyclicACMIFvPatchgpuField<Type>::updateInterfaceMatrix
(
    scalargpuField& result,
    const bool add,
    const gpulduAddressing& lduAddr,
    const label patchId,
    const scalargpuField& psiInternal,
    const scalargpuField& coeffs,
    const direction cmpt,
    const Pstream::commsTypes
) const
{
    // Cell addressing of the neighbour patch, taken from the fv patch
    const label nbrPatchId = cyclicACMIPatch_.neighbPatchID();
    const labelgpuList& nbrCellAddr = lduAddr.patchAddr(nbrPatchId);

    // Gather the neighbour-side internal values
    scalargpuField nbrPsi(psiInternal, nbrCellAddr);

    // Apply the coupling transformation for this component, then
    // interpolate across the ACMI
    transformCoupleField(nbrPsi, cmpt);
    nbrPsi = cyclicACMIPatch_.interpolate(nbrPsi);

    coupledFvPatchgpuField<Type>::updateInterfaceMatrix
    (
        result,
        coeffs,
        nbrPsi,
        !add
    );
}

// Update result with the coupled contribution of this interface for a
// solve on the full Type. Only the coupled part is applied here.
template<class Type>
void Foam::cyclicACMIFvPatchgpuField<Type>::updateInterfaceMatrix
(
    gpuField<Type>& result,
    const bool add,
    const gpulduAddressing& lduAddr,
    const label patchId,
    const gpuField<Type>& psiInternal,
    const scalargpuField& coeffs,
    const Pstream::commsTypes
) const
{
    // Cell addressing of the neighbour patch, taken from the fv patch
    const label nbrPatchId = cyclicACMIPatch_.neighbPatchID();
    const labelgpuList& nbrCellAddr = lduAddr.patchAddr(nbrPatchId);

    // Gather the neighbour-side internal values
    gpuField<Type> nbrPsi(psiInternal, nbrCellAddr);

    // Apply the coupling transformation, then interpolate across the ACMI
    transformCoupleField(nbrPsi);
    nbrPsi = cyclicACMIPatch_.interpolate(nbrPsi);

    this->addToInternalField
    (
        result,
        !add,
        patchId,
        lduAddr,
        coeffs,
        nbrPsi
    );
}


// Matrix manipulation hook. The ACMI itself does nothing here; the call is
// forwarded to the non-overlap patch field with weights (1 - mask).
template<class Type>
void Foam::cyclicACMIFvPatchgpuField<Type>::manipulateMatrix
(
    gpufvMatrix<Type>& matrix
)
{
    const scalargpuField& overlapMask =
        cyclicACMIPatch_.cyclicACMIPatch().gpumask();

    // nonOverlapPatchField() only hands out a const reference, so constness
    // is cast away to reach the non-const manipulateMatrix
    fvPatchgpuField<Type>& nonOverlapFld =
        const_cast<fvPatchgpuField<Type>&>(nonOverlapPatchField());

    nonOverlapFld.manipulateMatrix(matrix, 1.0 - overlapMask);
}


// Insert the implicit ACMI coupling into an assembled matrix for matrix
// number mat and component cmpt. Work is only done on the owner side.
// For every mapped face the upper (and, if asymmetric, lower) coefficient
// and the diagonals of both adjacent cells are corrected; when flux is
// required the mapped coefficients are also stored on both patches.
template<class Type>
void Foam::cyclicACMIFvPatchgpuField<Type>::manipulateMatrix
(
    gpufvMatrix<Type>& matrix,
    const label mat,
    const direction cmpt
)
{
    if (!this->cyclicACMIPatch().owner())
    {
        return;
    }

    const label patchi = this->patch().index();

    const label globalPatchID =
        matrix.lduMeshAssembly().patchLocalToGlobalMap()[mat][patchi];

    // Selected component of the patch coefficients held by the matrix
    const gpuField<scalar> intCoeffsCmpt
    (
        matrix.internalCoeffs()[globalPatchID].component(cmpt)
    );

    const gpuField<scalar> boundCoeffsCmpt
    (
        matrix.boundaryCoeffs()[globalPatchID].component(cmpt)
    );

    // Coefficients mapped onto the sub-faces by coeffs()
    tmp<gpuField<scalar>> tintCoeffs(coeffs(matrix, intCoeffsCmpt, mat));
    tmp<gpuField<scalar>> tbndCoeffs(coeffs(matrix, boundCoeffsCmpt, mat));
    const gpuField<scalar>& intCoeffs = tintCoeffs.ref();
    const gpuField<scalar>& bndCoeffs = tbndCoeffs.ref();

    const labelgpuList& u = matrix.lduAddr().upperAddr();
    const labelgpuList& l = matrix.lduAddr().lowerAddr();

    const labelList& faceMap =
        matrix.lduMeshAssembly().faceBoundMap()[mat][patchi];

    // Entry j of faceMap corresponds to sub-face j of the mapped coefficients
    forAll(faceMap, subFacei)
    {
        const label globalFacei = faceMap[subFacei];

        const label upperCelli = u.get(globalFacei);
        const label lowerCelli = l.get(globalFacei);

        const scalar boundCorr = -bndCoeffs.get(subFacei);
        const scalar intCorr = -intCoeffs.get(subFacei);

        // Read the current values before any of them is modified
        const scalar oldUpper = matrix.gpuUpper().get(globalFacei);
        const scalar oldDiagUpper = matrix.gpuDiag().get(upperCelli);
        const scalar oldDiagLower = matrix.gpuDiag().get(lowerCelli);

        matrix.gpuUpper().set(globalFacei, oldUpper + boundCorr);
        matrix.gpuDiag().set(upperCelli, oldDiagUpper - intCorr);
        matrix.gpuDiag().set(lowerCelli, oldDiagLower - boundCorr);

        if (matrix.asymmetric())
        {
            const scalar oldLower = matrix.gpuLower().get(globalFacei);
            matrix.gpuLower().set(globalFacei, oldLower + intCorr);
        }
    }

    // Set internalCoeffs and boundaryCoeffs in the assembly matrix
    // on cyclicACMI patches to be used in the individual matrix by
    // matrix.flux()
    const bool needsFlux =
        matrix.psi(mat).mesh().hostmesh().fluxRequired
        (
            this->internalField().name()
        );

    if (needsFlux)
    {
        matrix.internalCoeffs().set
        (
            globalPatchID, intCoeffs*pTraits<Type>::one
        );
        matrix.boundaryCoeffs().set
        (
            globalPatchID, bndCoeffs*pTraits<Type>::one
        );

        const label nbrPatchID =
            cyclicACMIPatch_.cyclicACMIPatch().neighbPatchID();

        const label nbrGlobalPatchID =
            matrix.lduMeshAssembly().patchLocalToGlobalMap()
            [mat][nbrPatchID];

        matrix.internalCoeffs().set
        (
            nbrGlobalPatchID, intCoeffs*pTraits<Type>::one
        );
        matrix.boundaryCoeffs().set
        (
            nbrGlobalPatchID, bndCoeffs*pTraits<Type>::one
        );
    }
}


// Map patch-face coefficients onto the assembly sub-faces of this patch.
//
// For every patch face, one sub-face entry exists per AMI source weight.
// Where the face mask exceeds the ACMI tolerance the entry is the weight
// times the coefficient of the mapped local face; all other entries stay
// zero.
//
// \param matrix  assembled matrix providing the lduMeshAssembly maps
// \param coeffs  per-face coefficients, indexed through facePatchFaceMap
// \param mat     matrix number used to index the assembly maps
// \return field of size cellBoundMap()[mat][patch index].size()
template<class Type>
Foam::tmp<Foam::gpuField<Foam::scalar>>
Foam::cyclicACMIFvPatchgpuField<Type>::coeffs
(
    gpufvMatrix<Type>& matrix,
    const gpuField<scalar>& coeffs,
    const label mat
) const
{
    const label index(this->patch().index());

    const label nSubFaces
    (
        matrix.lduMeshAssembly().cellBoundMap()[mat][index].size()
    );

    // Allocate the result directly inside the tmp so that returning it does
    // not require a second, deep-copied field
    tmp<gpuField<scalar>> tmapCoeffs(new gpuField<scalar>(nSubFaces, Zero));
    gpuField<scalar>& mapCoeffs = tmapCoeffs.ref();

    const scalarListList& srcWeight =
        cyclicACMIPatch_.cyclicACMIPatch().AMI().srcWeights();

    const scalarField& mask = cyclicACMIPatch_.cyclicACMIPatch().mask();

    const scalar tol = cyclicACMIPolyPatch::tolerance();

    // Sub-face to local face map; invariant over the loops below
    const auto& subFaceToFace =
        matrix.lduMeshAssembly().facePatchFaceMap()[mat][index];

    label subFaceI = 0;
    forAll(mask, faceI)
    {
        const scalarList& w = srcWeight[faceI];

        if (mask[faceI] > tol)
        {
            forAll(w, i)
            {
                const label localFaceId = subFaceToFace[subFaceI];
                mapCoeffs.set(subFaceI, w[i]*coeffs.get(localFaceId));
                ++subFaceI;
            }
        }
        else
        {
            // Masked-out face: its sub-faces keep the initial zero value
            subFaceI += w.size();
        }
    }

    return tmapCoeffs;
}


// Update the coefficients by delegating to the non-overlap patch field.
// Some non-overlap patch types implement updateCoeffs and others implement
// evaluate; (1 - mask) is passed so the non-overlap patch can manipulate
// non-face based data.
template<class Type>
void Foam::cyclicACMIFvPatchgpuField<Type>::updateCoeffs()
{
    const scalargpuField& overlapMask =
        cyclicACMIPatch_.cyclicACMIPatch().gpumask();

    // nonOverlapPatchField() only hands out a const reference, so constness
    // is cast away to reach the non-const updateWeightedCoeffs
    fvPatchgpuField<Type>& nonOverlapFld =
        const_cast<fvPatchgpuField<Type>&>(nonOverlapPatchField());

    nonOverlapFld.updateWeightedCoeffs(1.0 - overlapMask);
}


// Write the patch field to the given stream: first the entries written by
// fvPatchgpuField<Type>::write, then the field values under the keyword
// "value".
template<class Type>
void Foam::cyclicACMIFvPatchgpuField<Type>::write(Ostream& os) const
{
    // Explicitly qualified call to the fvPatchgpuField<Type> implementation
    fvPatchgpuField<Type>::write(os);
    this->writeEntry("value", os);
}


// ************************************************************************* //
