/*---------------------------------------------------------------------------*\
  =========                 |
  \\      /  F ield         | OpenFOAM: The Open Source CFD Toolbox
   \\    /   O peration     |
    \\  /    A nd           | www.openfoam.com
     \\/     M anipulation  |
-------------------------------------------------------------------------------
    Copyright (C) 2011-2017 OpenFOAM Foundation
    Copyright (C) 2019-2021 OpenCFD Ltd.
-------------------------------------------------------------------------------
License
    This file is part of OpenFOAM.

    OpenFOAM is free software: you can redistribute it and/or modify it
    under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    OpenFOAM is distributed in the hope that it will be useful, but WITHOUT
    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
    for more details.

    You should have received a copy of the GNU General Public License
    along with OpenFOAM.  If not, see <http://www.gnu.org/licenses/>.

\*---------------------------------------------------------------------------*/

#include "processorFvPatchgpuField.H"
#include "processorgpuFvPatch.H"
#include "demandDrivenData.H"
#include "transformField.H"

#include "gpulduAddressingFunctors.H"

// * * * * * * * * * * * * * * * * Constructors * * * * * * * * * * * * * * //

// Construct from patch and internal field.
// The patch is cast to a processorgpuFvPatch via refCast. Every host and
// device communication buffer is constructed with argument 0, and both
// request indices are set to -1 (ready() only treats indices >= 0 as
// potentially outstanding).
template<class Type>
Foam::processorFvPatchgpuField<Type>::processorFvPatchgpuField
(
    const gpufvPatch& p,
    const DimensionedgpuField<Type, gpuvolMesh>& iF
)
:
    coupledFvPatchgpuField<Type>(p, iF),
    procPatch_(refCast<const processorgpuFvPatch>(p)),
    sendBuf_(0),
    receiveBuf_(0),
    outstandingSendRequest_(-1),
    outstandingRecvRequest_(-1),
    scalarSendBuf_(0),
    scalarReceiveBuf_(0),
    gpuSendBuf_(0),
    gpuReceiveBuf_(0),
    scalargpuSendBuf_(0),
    scalargpuReceiveBuf_(0)
{}


// Construct from patch, internal field and a patch field f, which is
// forwarded to the coupledFvPatchgpuField base constructor.
// Communication buffers are constructed with argument 0 and both request
// indices are set to -1 (no request recorded).
template<class Type>
Foam::processorFvPatchgpuField<Type>::processorFvPatchgpuField
(
    const gpufvPatch& p,
    const DimensionedgpuField<Type, gpuvolMesh>& iF,
    const gpuField<Type>& f
)
:
    coupledFvPatchgpuField<Type>(p, iF, f),
    procPatch_(refCast<const processorgpuFvPatch>(p)),
    sendBuf_(0),
    receiveBuf_(0),
    outstandingSendRequest_(-1),
    outstandingRecvRequest_(-1),
    scalarSendBuf_(0),
    scalarReceiveBuf_(0),
    gpuSendBuf_(0),
    gpuReceiveBuf_(0),
    scalargpuSendBuf_(0),
    scalargpuReceiveBuf_(0)
{}


// Construct from patch, internal field and dictionary.
// The base class is told whether the dictionary carries a "value" entry;
// when it does not, the patch values are taken from the adjacent internal
// field instead.
template<class Type>
Foam::processorFvPatchgpuField<Type>::processorFvPatchgpuField
(
    const gpufvPatch& p,
    const DimensionedgpuField<Type, gpuvolMesh>& iF,
    const dictionary& dict
)
:
    coupledFvPatchgpuField<Type>(p, iF, dict, dict.found("value")),
    procPatch_(refCast<const processorgpuFvPatch>(p, dict)),
    sendBuf_(0),
    receiveBuf_(0),
    outstandingSendRequest_(-1),
    outstandingRecvRequest_(-1),
    scalarSendBuf_(0),
    scalarReceiveBuf_(0),
    gpuSendBuf_(0),
    gpuReceiveBuf_(0),
    scalargpuSendBuf_(0),
    scalargpuReceiveBuf_(0)
{
    // This field type is only valid on a processor patch
    if (!isA<processorgpuFvPatch>(p))
    {
        FatalIOErrorInFunction(dict)
            << "\n    patch type '" << p.type()
            << "' not constraint type '" << typeName << "'"
            << "\n    for patch " << p.name()
            << " of field " << this->internalField().name()
            << " in file " << this->internalField().objectPath()
            << exit(FatalIOError);
    }

    // Without a "value" entry, initialise from the internal field
    const bool hasValueEntry = dict.found("value");

    if (!hasValueEntry)
    {
        fvPatchgpuField<Type>::operator=(this->patchInternalField());
    }
}


// Construct by mapping the given patch field ptf onto a new patch p.
// Communication buffers are constructed with argument 0 and both request
// indices are set to -1; nothing is taken over from ptf's buffers.
// Fatal error if the target patch is not a processorgpuFvPatch, or (in
// debug mode) if ptf reports that it is not ready().
template<class Type>
Foam::processorFvPatchgpuField<Type>::processorFvPatchgpuField
(
    const processorFvPatchgpuField<Type>& ptf,
    const gpufvPatch& p,
    const DimensionedgpuField<Type, gpuvolMesh>& iF,
    const fvPatchgpuFieldMapper& mapper
)
:
    coupledFvPatchgpuField<Type>(ptf, p, iF, mapper),
    procPatch_(refCast<const processorgpuFvPatch>(p)),
    sendBuf_(0),
    receiveBuf_(0),
    outstandingSendRequest_(-1),
    outstandingRecvRequest_(-1),
    scalarSendBuf_(0),
    scalarReceiveBuf_(0),
    gpuSendBuf_(0),
    gpuReceiveBuf_(0),
    scalargpuSendBuf_(0),
    scalargpuReceiveBuf_(0)
{
    if (!isA<processorgpuFvPatch>(this->patch()))
    {
        // Same wording as the dictionary constructor: report the offending
        // patch type before the expected constraint type
        FatalErrorInFunction
            << "\n    patch type '" << p.type()
            << "' not constraint type '" << typeName << "'"
            << "\n    for patch " << p.name()
            << " of field " << this->internalField().name()
            << " in file " << this->internalField().objectPath()
            << exit(FatalError);
    }
    if (debug && !ptf.ready())
    {
        FatalErrorInFunction
            << "On patch " << procPatch_.name() << " outstanding request."
            << abort(FatalError);
    }
}


// Copy construct.
// Request indices are reset to -1 rather than copied from ptf.
// The device buffers (gpu*Buf_) are constructed from ptf's buffers.
// The host buffers are constructed from std::move(ptf.<buffer>).
// NOTE(review): ptf is a const reference, so whether this actually moves
// (emptying the source buffers) or silently copies depends on whether these
// members are declared mutable in the header, which is not visible here -
// confirm the intent.
// In debug mode it is a fatal error if ptf is not ready().
template<class Type>
Foam::processorFvPatchgpuField<Type>::processorFvPatchgpuField
(
    const processorFvPatchgpuField<Type>& ptf
)
:
    processorLduInterfacegpuField(),
    coupledFvPatchgpuField<Type>(ptf),
    procPatch_(refCast<const processorgpuFvPatch>(ptf.patch())),
    sendBuf_(std::move(ptf.sendBuf_)),
    receiveBuf_(std::move(ptf.receiveBuf_)),
    outstandingSendRequest_(-1),
    outstandingRecvRequest_(-1),
    scalarSendBuf_(std::move(ptf.scalarSendBuf_)),
    scalarReceiveBuf_(std::move(ptf.scalarReceiveBuf_)),
    gpuSendBuf_(ptf.gpuSendBuf_),
    gpuReceiveBuf_(ptf.gpuReceiveBuf_),
    scalargpuSendBuf_(ptf.scalargpuSendBuf_),
    scalargpuReceiveBuf_(ptf.scalargpuReceiveBuf_)
{
    // A copy taken while communication is still pending is treated as a
    // programming error (checked in debug mode only)
    if (debug && !ptf.ready())
    {
        FatalErrorInFunction
            << "On patch " << procPatch_.name() << " outstanding request."
            << abort(FatalError);
    }
}


// Copy construct, setting a new internal field reference.
// Buffers are constructed with argument 0 and request indices set to -1;
// none of ptf's buffers are taken over.
template<class Type>
Foam::processorFvPatchgpuField<Type>::processorFvPatchgpuField
(
    const processorFvPatchgpuField<Type>& ptf,
    const DimensionedgpuField<Type, gpuvolMesh>& iF
)
:
    coupledFvPatchgpuField<Type>(ptf, iF),
    procPatch_(refCast<const processorgpuFvPatch>(ptf.patch())),
    sendBuf_(0),
    receiveBuf_(0),
    outstandingSendRequest_(-1),
    outstandingRecvRequest_(-1),
    scalarSendBuf_(0),
    scalarReceiveBuf_(0),
    gpuSendBuf_(0),
    gpuReceiveBuf_(0),
    scalargpuSendBuf_(0),
    scalargpuReceiveBuf_(0)
{
    // Debug-only sanity check: the source must not be mid-communication
    const bool sourceBusy = debug && !ptf.ready();

    if (sourceBusy)
    {
        FatalErrorInFunction
            << "On patch " << procPatch_.name() << " outstanding request."
            << abort(FatalError);
    }
}


// * * * * * * * * * * * * * * * Member Functions  * * * * * * * * * * * * * //

// Return the neighbour-side values, which for this patch type are the
// patch field's own values (*this).
// In debug mode it is a fatal error to call this while ready() is false.
template<class Type>
Foam::tmp<Foam::gpuField<Type>>
Foam::processorFvPatchgpuField<Type>::patchNeighbourField() const
{
    if (debug)
    {
        if (!this->ready())
        {
            FatalErrorInFunction
                << "On patch " << procPatch_.name()
                << " outstanding request."
                << abort(FatalError);
        }
    }

    return *this;
}


// Start the exchange of patch-internal values with the neighbour processor.
// Does nothing unless Pstream::parRun().
//
// The patch-internal field is gathered into gpuSendBuf_. For non-blocking
// comms without float transfer, a non-blocking receive and then a
// non-blocking send are posted and their request indices recorded:
//   - with Pstream::gpuDirectTransfer the pointers come from gpuSendBuf_
//     and from this field's own storage;
//   - otherwise the data is staged through sendBuf_/receiveBuf_.
// All other cases delegate to procPatch_.compressedSend().
template<class Type>
void Foam::processorFvPatchgpuField<Type>::initEvaluate
(
    const Pstream::commsTypes commsType
)
{
    if (!Pstream::parRun())
    {
        return;
    }

    this->patchInternalField(gpuSendBuf_);

    const bool useFastPath =
    (
        commsType == Pstream::commsTypes::nonBlocking
     && !Pstream::floatTransfer
    );

    if (!useFastPath)
    {
        procPatch_.compressedSend(commsType, gpuSendBuf_);
        return;
    }

    // Raw byte transfer below requires a contiguous data type
    if (!is_contiguous<Type>::value)
    {
        FatalErrorInFunction
            << "Invalid for non-contiguous data types"
            << abort(FatalError);
    }

    const std::streamsize nBytes = gpuSendBuf_.byteSize();

    // The patch field itself must be able to hold the incoming values
    this->setSize(gpuSendBuf_.size());

    const Type* sendPtr = nullptr;
    Type* recvPtr = nullptr;

    if (Pstream::gpuDirectTransfer)
    {
        // Hand the gpuField pointers straight to the transfer calls
        sendPtr = gpuSendBuf_.data();
        recvPtr = this->data();
    }
    else
    {
        // Stage through sendBuf_/receiveBuf_
        sendBuf_.setSize(gpuSendBuf_.size());
        receiveBuf_.setSize(sendBuf_.size());

        thrust::copy
        (
            gpuSendBuf_.begin(),
            gpuSendBuf_.end(),
            sendBuf_.begin()
        );

        sendPtr = sendBuf_.begin();
        recvPtr = receiveBuf_.begin();
    }

    // Post the receive first, remembering the index it will be given
    outstandingRecvRequest_ = UPstream::nRequests();
    UIPstream::read
    (
        Pstream::commsTypes::nonBlocking,
        procPatch_.neighbProcNo(),
        reinterpret_cast<char*>(recvPtr),
        nBytes,
        procPatch_.tag(),
        procPatch_.comm()
    );

    // Then the matching send
    outstandingSendRequest_ = UPstream::nRequests();
    UOPstream::write
    (
        Pstream::commsTypes::nonBlocking,
        procPatch_.neighbProcNo(),
        reinterpret_cast<const char*>(sendPtr),
        nBytes,
        procPatch_.tag(),
        procPatch_.comm()
    );
}


// Complete the exchange started by initEvaluate() and store the neighbour
// values in this patch field. Does nothing unless Pstream::parRun().
//
// For non-blocking comms without float transfer the outstanding receive is
// waited on and both request indices are reset to -1. With
// Pstream::gpuDirectTransfer the data was received directly into *this;
// otherwise it is copied from receiveBuf_ into *this.
// All other cases delegate to procPatch_.compressedReceive().
// Finally the received values are transformed if doTransform() is true.
template<class Type>
void Foam::processorFvPatchgpuField<Type>::evaluate
(
    const Pstream::commsTypes commsType
)
{
    if (Pstream::parRun())
    {
        if
        (
            commsType == Pstream::commsTypes::nonBlocking
         && !Pstream::floatTransfer
        )
        {
            // Fast path.

            if
            (
                outstandingRecvRequest_ >= 0
             && outstandingRecvRequest_ < Pstream::nRequests()
            )
            {
                UPstream::waitRequest(outstandingRecvRequest_);
            }

            // Recv finished so assume sending finished as well.
            outstandingSendRequest_ = -1;
            outstandingRecvRequest_ = -1;

            if (!Pstream::gpuDirectTransfer)
            {
                // Only the Type-valued buffer takes part in evaluate();
                // the scalar interface buffers are not touched here.
                thrust::copy
                (
                    receiveBuf_.begin(),
                    receiveBuf_.end(),
                    this->begin()
                );
            }
        }
        else
        {
            procPatch_.compressedReceive<Type>(commsType, *this);
        }

        if (doTransform())
        {
            transform(*this, procPatch_.gpuForwardT(), *this);
        }
    }
}


// Return the patch-normal gradient, computed as
//     deltaCoeffs * (patch values - patch-internal values)
// where the patch values are *this and the patch-internal values come from
// patchInternalField().
template<class Type>
Foam::tmp<Foam::gpuField<Type>>
Foam::processorFvPatchgpuField<Type>::snGrad
(
    const scalargpuField& deltaCoeffs
) const
{
    return deltaCoeffs*(*this - this->patchInternalField());
}


// Start the interface exchange for a scalar solve variable.
//
// The values of psiInternal at the patch face cells (lduAddr.patchAddr for
// patchId) are gathered into scalargpuSendBuf_. For non-blocking comms
// without float transfer, a non-blocking receive and then a non-blocking
// send are posted and their request indices recorded:
//   - with Pstream::gpuDirectTransfer the receive targets the member
//     scalargpuReceiveBuf_, which updateInterfaceMatrix() consumes;
//   - otherwise the data is staged through scalarSendBuf_ and
//     scalarReceiveBuf_.
// All other cases delegate to procPatch_.compressedSend().
// On exit updatedMatrix() is set to false.
//
// The result field, 'add' flag, coefficients and component arguments are
// not used at this stage.
template<class Type>
void Foam::processorFvPatchgpuField<Type>::initInterfaceMatrixUpdate
(
    scalargpuField&,
    const bool add,
    const gpulduAddressing& lduAddr,
    const label patchId,
    const scalargpuField& psiInternal,
    const scalargpuField&,
    const direction,
    const Pstream::commsTypes commsType
) const
{
    const labelgpuList& faceCells = lduAddr.patchAddr(patchId);

    scalargpuSendBuf_.setSize(this->patch().size());

    // Gather psiInternal[faceCells[i]] into the send buffer
    thrust::copy
    (
        thrust::make_permutation_iterator
        (
            psiInternal.begin(),
            faceCells.begin()
        ),
        thrust::make_permutation_iterator
        (
            psiInternal.begin(),
            faceCells.end()
        ),
        scalargpuSendBuf_.begin()
    );

    if
    (
        commsType == Pstream::commsTypes::nonBlocking
     && !Pstream::floatTransfer
    )
    {
        // Fast path.
        if (debug && !this->ready())
        {
            FatalErrorInFunction
                << "On patch " << procPatch_.name()
                << " outstanding request."
                << abort(FatalError);
        }

        const std::streamsize nBytes = scalargpuSendBuf_.byteSize();

        scalar* receive;
        const scalar* send;

        if (Pstream::gpuDirectTransfer)
        {
            // Size the MEMBER receive buffer. The storage handed to the
            // non-blocking receive must stay alive until
            // updateInterfaceMatrix() has waited on the request, so it
            // cannot be a block-local buffer.
            scalargpuReceiveBuf_.setSize(scalargpuSendBuf_.size());

            send = scalargpuSendBuf_.data();
            receive = scalargpuReceiveBuf_.data();
        }
        else
        {
            // Stage through scalarSendBuf_/scalarReceiveBuf_
            scalarSendBuf_.setSize(scalargpuSendBuf_.size());
            scalarReceiveBuf_.setSize(scalarSendBuf_.size());
            thrust::copy
            (
                scalargpuSendBuf_.begin(),
                scalargpuSendBuf_.end(),
                scalarSendBuf_.begin()
            );

            send = scalarSendBuf_.begin();
            receive = scalarReceiveBuf_.begin();
        }

        // Post the receive first, remembering the index it will be given
        outstandingRecvRequest_ = UPstream::nRequests();
        UIPstream::read
        (
            Pstream::commsTypes::nonBlocking,
            procPatch_.neighbProcNo(),
            reinterpret_cast<char*>(receive),
            nBytes,
            procPatch_.tag(),
            procPatch_.comm()
        );

        outstandingSendRequest_ = UPstream::nRequests();
        UOPstream::write
        (
            Pstream::commsTypes::nonBlocking,
            procPatch_.neighbProcNo(),
            reinterpret_cast<const char*>(send),
            nBytes,
            procPatch_.tag(),
            procPatch_.comm()
        );
    }
    else
    {
        procPatch_.compressedSend(commsType, scalargpuSendBuf_);
    }

    const_cast<processorFvPatchgpuField<Type>&>(*this).updatedMatrix() = false;
}


// Complete the scalar interface exchange and accumulate the neighbour
// contribution into result. Returns immediately if updatedMatrix() is
// already true; sets it to true on completion.
//
// The neighbour values end up in scalargpuReceiveBuf_:
//   - non-blocking comms without float transfer: wait on the outstanding
//     receive, reset both request indices and, unless
//     Pstream::gpuDirectTransfer is set, copy scalarReceiveBuf_ into
//     scalargpuReceiveBuf_;
//   - otherwise: resize the buffer and fill it with
//     procPatch_.compressedReceive().
// For non-arithmetic Type the buffer is then passed through
// transformCoupleField() for component cmpt, and finally handed to the base
// class updateInterfaceMatrix() together with coeffs and the negated 'add'
// flag.
template<class Type>
void Foam::processorFvPatchgpuField<Type>::updateInterfaceMatrix
(
    scalargpuField& result,
    const bool add,
    const gpulduAddressing& lduAddr,
    const label patchId,
    const scalargpuField&,
    const scalargpuField& coeffs,
    const direction cmpt,
    const Pstream::commsTypes commsType
) const
{
    if (this->updatedMatrix())
    {
        return;
    }

    const bool useFastPath =
    (
        commsType == Pstream::commsTypes::nonBlocking
     && !Pstream::floatTransfer
    );

    if (useFastPath)
    {
        const bool recvPending =
        (
            outstandingRecvRequest_ >= 0
         && outstandingRecvRequest_ < Pstream::nRequests()
        );

        if (recvPending)
        {
            UPstream::waitRequest(outstandingRecvRequest_);
        }

        // Receive is complete; the send is assumed complete as well
        outstandingSendRequest_ = -1;
        outstandingRecvRequest_ = -1;

        if (!Pstream::gpuDirectTransfer)
        {
            // Staged transfer: bring scalarReceiveBuf_ into the gpu buffer
            scalargpuReceiveBuf_ = scalarReceiveBuf_;
        }
    }
    else
    {
        scalargpuReceiveBuf_.setSize(this->size());
        procPatch_.compressedReceive<scalar>(commsType, scalargpuReceiveBuf_);
    }

    // Component transformation is only applied for non-arithmetic Type
    if (!std::is_arithmetic<Type>::value)
    {
        transformCoupleField(scalargpuReceiveBuf_, cmpt);
    }

    // Scale by the coefficients and accumulate into the result
    coupledFvPatchgpuField<Type>::updateInterfaceMatrix
    (
        result,
        coeffs,
        scalargpuReceiveBuf_,
        !add
    );

    const_cast<processorFvPatchgpuField<Type>&>(*this).updatedMatrix() = true;
}

// Start the interface exchange for a Type-valued solve variable.
//
// The values of psiInternal at the patch face cells (lduAddr.patchAddr for
// patchId) are gathered into gpuSendBuf_. For non-blocking comms without
// float transfer, a non-blocking receive and then a non-blocking send are
// posted and their request indices recorded:
//   - with Pstream::gpuDirectTransfer the receive targets the member
//     gpuReceiveBuf_, which updateInterfaceMatrix() consumes;
//   - otherwise the data is staged through sendBuf_ and receiveBuf_.
// All other cases delegate to procPatch_.compressedSend().
// On exit updatedMatrix() is set to false.
//
// The result field, 'add' flag and coefficients arguments are not used at
// this stage.
template<class Type>
void Foam::processorFvPatchgpuField<Type>::initInterfaceMatrixUpdate
(
    gpuField<Type>&,
    const bool add,
    const gpulduAddressing& lduAddr,
    const label patchId,
    const gpuField<Type>& psiInternal,
    const scalargpuField&,
    const Pstream::commsTypes commsType
) const
{
    gpuSendBuf_.setSize(this->patch().size());

    const labelgpuList& faceCells = lduAddr.patchAddr(patchId);

    // Gather psiInternal[faceCells[i]] into the send buffer
    thrust::copy
    (
        thrust::make_permutation_iterator
        (
            psiInternal.begin(),
            faceCells.begin()
        ),
        thrust::make_permutation_iterator
        (
            psiInternal.begin(),
            faceCells.end()
        ),
        gpuSendBuf_.begin()
    );

    if
    (
        commsType == Pstream::commsTypes::nonBlocking
     && !Pstream::floatTransfer
    )
    {
        // Fast path.
        if (debug && !this->ready())
        {
            FatalErrorInFunction
                << "On patch " << procPatch_.name()
                << " outstanding request."
                << abort(FatalError);
        }

        const std::streamsize nBytes = gpuSendBuf_.byteSize();

        Type* receive;
        const Type* send;

        if (Pstream::gpuDirectTransfer)
        {
            // Size the MEMBER receive buffer. The storage handed to the
            // non-blocking receive must stay alive until
            // updateInterfaceMatrix() has waited on the request, so it
            // cannot be a block-local buffer.
            gpuReceiveBuf_.setSize(gpuSendBuf_.size());

            send = gpuSendBuf_.data();
            receive = gpuReceiveBuf_.data();
        }
        else
        {
            // Stage through sendBuf_/receiveBuf_
            sendBuf_.setSize(gpuSendBuf_.size());
            receiveBuf_.setSize(sendBuf_.size());
            thrust::copy
            (
                gpuSendBuf_.begin(),
                gpuSendBuf_.end(),
                sendBuf_.begin()
            );

            send = sendBuf_.begin();
            receive = receiveBuf_.begin();
        }

        // Post the receive first, remembering the index it will be given
        outstandingRecvRequest_ = UPstream::nRequests();
        IPstream::read
        (
            Pstream::commsTypes::nonBlocking,
            procPatch_.neighbProcNo(),
            reinterpret_cast<char*>(receive),
            nBytes,
            procPatch_.tag(),
            procPatch_.comm()
        );

        outstandingSendRequest_ = UPstream::nRequests();
        OPstream::write
        (
            Pstream::commsTypes::nonBlocking,
            procPatch_.neighbProcNo(),
            reinterpret_cast<const char*>(send),
            nBytes,
            procPatch_.tag(),
            procPatch_.comm()
        );
    }
    else
    {
        procPatch_.compressedSend(commsType, gpuSendBuf_);
    }

    const_cast<processorFvPatchgpuField<Type>&>(*this).updatedMatrix() = false;
}


// Complete the Type-valued interface exchange and accumulate the neighbour
// contribution into result via addToInternalField(). Returns immediately if
// updatedMatrix() is already true; sets it to true on completion.
//
//   - Non-blocking comms without float transfer: wait on the outstanding
//     receive, reset both request indices, copy receiveBuf_ into
//     gpuReceiveBuf_ unless Pstream::gpuDirectTransfer is set, apply
//     transformCoupleField() for non-arithmetic Type, then accumulate.
//   - Otherwise: receive into a temporary with
//     procPatch_.compressedReceive(), apply transformCoupleField()
//     unconditionally, then accumulate.
template<class Type>
void Foam::processorFvPatchgpuField<Type>::updateInterfaceMatrix
(
    gpuField<Type>& result,
    const bool add,
    const gpulduAddressing& lduAddr,
    const label patchId,
    const gpuField<Type>&,
    const scalargpuField& coeffs,
    const Pstream::commsTypes commsType
) const
{
    if (this->updatedMatrix())
    {
        return;
    }

    // Looked up as in the scalar variant; not referenced further below
    const labelgpuList& faceCells = lduAddr.patchAddr(patchId);

    const bool useFastPath =
    (
        commsType == Pstream::commsTypes::nonBlocking
     && !Pstream::floatTransfer
    );

    if (!useFastPath)
    {
        gpuField<Type> pnf
        (
            procPatch_.compressedReceive<Type>(commsType, this->size())()
        );

        // Transform according to the transformation tensor
        transformCoupleField(pnf);

        // Scale by the coefficients and accumulate into the result
        this->addToInternalField
        (
            result,
            !add,
            patchId,
            lduAddr,
            coeffs,
            pnf
        );
    }
    else
    {
        const bool recvPending =
        (
            outstandingRecvRequest_ >= 0
         && outstandingRecvRequest_ < Pstream::nRequests()
        );

        if (recvPending)
        {
            UPstream::waitRequest(outstandingRecvRequest_);
        }

        // Receive is complete; the send is assumed complete as well
        outstandingSendRequest_ = -1;
        outstandingRecvRequest_ = -1;

        if (!Pstream::gpuDirectTransfer)
        {
            // Staged transfer: bring receiveBuf_ into the gpu buffer
            gpuReceiveBuf_ = receiveBuf_;
        }

        if (!std::is_arithmetic<Type>::value)
        {
            // Transform according to the transformation tensor
            transformCoupleField(gpuReceiveBuf_);
        }

        // Scale by the coefficients and accumulate into the result
        this->addToInternalField
        (
            result,
            !add,
            patchId,
            lduAddr,
            coeffs,
            gpuReceiveBuf_
        );
    }

    const_cast<processorFvPatchgpuField<Type>&>(*this).updatedMatrix() = true;
}


// Return true when neither the recorded send request nor the recorded
// receive request is still in progress.
//
// A request index is only queried while it lies in [0, Pstream::nRequests()).
// The send request is checked first; once it is found finished (or out of
// range) its index is reset to -1 and the receive request is treated the
// same way. Returns false as soon as an unfinished request is found, leaving
// that index (and any later one) untouched.
template<class Type>
bool Foam::processorFvPatchgpuField<Type>::ready() const
{
    // Send side
    if
    (
        outstandingSendRequest_ >= 0
     && outstandingSendRequest_ < Pstream::nRequests()
     && !UPstream::finishedRequest(outstandingSendRequest_)
    )
    {
        return false;
    }
    outstandingSendRequest_ = -1;

    // Receive side
    if
    (
        outstandingRecvRequest_ >= 0
     && outstandingRecvRequest_ < Pstream::nRequests()
     && !UPstream::finishedRequest(outstandingRecvRequest_)
    )
    {
        return false;
    }
    outstandingRecvRequest_ = -1;

    return true;
}


// ************************************************************************* //
