/*---------------------------------------------------------------------------*\
  =========                 |
  \\      /  F ield         | OpenFOAM: The Open Source CFD Toolbox
   \\    /   O peration     |
    \\  /    A nd           | www.openfoam.com
     \\/     M anipulation  |
-------------------------------------------------------------------------------
    Copyright (C) 2011-2017 OpenFOAM Foundation
    Copyright (C) 2019-2021 OpenCFD Ltd.
-------------------------------------------------------------------------------
License
    This file is part of OpenFOAM.

    OpenFOAM is free software: you can redistribute it and/or modify it
    under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    OpenFOAM is distributed in the hope that it will be useful, but WITHOUT
    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
    for more details.

    You should have received a copy of the GNU General Public License
    along with OpenFOAM.  If not, see <http://www.gnu.org/licenses/>.

\*---------------------------------------------------------------------------*/

#include "processorgpuLduInterface.H"
#include "IPstream.H"
#include "OPstream.H"

#include "DeviceMemory.H"
#include <thrust/iterator/counting_iterator.h>

// * * * * * * * * * * * * * * * Member Functions * * *  * * * * * * * * * * //

namespace Foam
{

// Element-wise functor that converts a component stream between two
// arithmetic representations while applying a per-component offset.
//
// Template parameters
//     To       - destination component type
//     From     - source component type
//     compress - true:  to[i] = from[i] - slast[i % nCmpts]
//                false: to[i] = from[i] + slast[i % nCmpts]
//
// The call sites in this file pass the components of the last field element
// as slast, so values are transferred relative to that element.
template<class To, class From, bool compress>
struct compressFunctor
{
    // Destination component array (written at index i)
    To* to;

    // Source component array (read at index i)
    const From* from;

    // Offset components, indexed by (i % nCmpts)
    const scalar* slast;

    // Number of components per field element
    const label nCmpts;

    compressFunctor
    (
        To* dest,
        const From* src,
        const scalar* offsets,
        label cmptCount
    )
    :
        to(dest),
        from(src),
        slast(offsets),
        nCmpts(cmptCount)
    {}

    // Convert the single component with flat index i
    __host__ __device__
    void operator()(const label& i)
    {
        // Offset belonging to the same component slot as entry i
        const scalar& offset = slast[i%nCmpts];

        to[i] = static_cast<To>
        (
            compress ? (from[i] - offset) : (from[i] + offset)
        );
    }
};

}

// Send the raw bytes of field f to the neighbouring processor.
//
// blocking / scheduled:
//     The data are written with OPstream::write, either straight from
//     f.data() (Pstream::gpuDirectTransfer) or from the host buffer sendBuf_
//     after a device-to-host copy.
//
// nonBlocking:
//     The data are first copied into a persistent send buffer (gpuSendBuf_
//     or sendBuf_), then IPstream::read is called on the matching receive
//     buffer before OPstream::write is called. The received bytes are picked
//     up later by receive().
//
// Any other communications type is a fatal error.
template<class Type>
void Foam::processorgpuLduInterface::send
(
    const Pstream::commsTypes commsType,
    const gpuList<Type>& f
) const
{
    const label nBytes = f.byteSize();

    switch (commsType)
    {
        case Pstream::commsTypes::blocking:
        case Pstream::commsTypes::scheduled:
        {
            const char* payload;

            if (!Pstream::gpuDirectTransfer)
            {
                // Stage through the host buffer
                resizeBuf(sendBuf_, nBytes);
                copyDeviceToHost(sendBuf_.begin(), f.data(), nBytes);
                payload = sendBuf_.begin();
            }
            else
            {
                // Hand the field storage to the stream as-is
                payload = reinterpret_cast<const char*>(f.data());
            }

            OPstream::write
            (
                commsType,
                neighbProcNo(),
                payload,
                nBytes,
                tag(),
                comm()
            );
            break;
        }

        case Pstream::commsTypes::nonBlocking:
        {
            char* incoming;
            const char* outgoing;

            if (!Pstream::gpuDirectTransfer)
            {
                // Host-side staging buffers
                resizeBuf(receiveBuf_, nBytes);
                resizeBuf(sendBuf_, nBytes);

                copyDeviceToHost(sendBuf_.begin(), f.data(), nBytes);

                outgoing = sendBuf_.begin();
                incoming = receiveBuf_.begin();
            }
            else
            {
                // Device-side staging buffers
                resizeBuf(gpuReceiveBuf_, nBytes);
                resizeBuf(gpuSendBuf_, nBytes);

                copyDeviceToDevice(gpuSendBuf_.data(), f.data(), nBytes);

                outgoing = gpuSendBuf_.data();
                incoming = gpuReceiveBuf_.data();
            }

            // Read is issued ahead of the write
            IPstream::read
            (
                commsType,
                neighbProcNo(),
                incoming,
                nBytes,
                tag(),
                comm()
            );

            OPstream::write
            (
                commsType,
                neighbProcNo(),
                outgoing,
                nBytes,
                tag(),
                comm()
            );
            break;
        }

        default:
        {
            FatalErrorInFunction
                << "Unsupported communications type " << int(commsType)
                << exit(FatalError);
        }
    }
}


// Receive the raw bytes of field f from the neighbouring processor.
//
// blocking / scheduled:
//     IPstream::read fills either f itself (Pstream::gpuDirectTransfer) or
//     the host buffer receiveBuf_, which is then copied to f.
//
// nonBlocking:
//     No read is issued here; the bytes are copied into f from the receive
//     buffer (gpuReceiveBuf_ or receiveBuf_) that send() handed to
//     IPstream::read.
//     NOTE(review): presumably the caller has waited for the outstanding
//     requests before calling this - confirm against the callers.
//
// Any other communications type is a fatal error.
template<class Type>
void Foam::processorgpuLduInterface::receive
(
    const Pstream::commsTypes commsType,
    gpuList<Type>& f
) const
{
    const auto nBytes = f.byteSize();

    switch (commsType)
    {
        case Pstream::commsTypes::blocking:
        case Pstream::commsTypes::scheduled:
        {
            char* target;

            if (!Pstream::gpuDirectTransfer)
            {
                // Read into the host staging buffer
                resizeBuf(receiveBuf_, nBytes);
                target = receiveBuf_.begin();
            }
            else
            {
                // Read straight into the field storage
                target = reinterpret_cast<char*>(f.data());
            }

            IPstream::read
            (
                commsType,
                neighbProcNo(),
                target,
                nBytes,
                tag(),
                comm()
            );

            if (!Pstream::gpuDirectTransfer)
            {
                copyHostToDevice(f.data(), receiveBuf_.data(), nBytes);
            }
            break;
        }

        case Pstream::commsTypes::nonBlocking:
        {
            if (!Pstream::gpuDirectTransfer)
            {
                copyHostToDevice(f.data(), receiveBuf_.data(), nBytes);
            }
            else
            {
                copyDeviceToDevice(f.data(), gpuReceiveBuf_.data(), nBytes);
            }
            break;
        }

        default:
        {
            FatalErrorInFunction
                << "Unsupported communications type " << int(commsType)
                << exit(FatalError);
        }
    }
}


// Allocate a new field with 'size' elements, fill it via the in-place
// receive() overload and return it as a tmp.
template<class Type>
Foam::tmp<Foam::gpuField<Type>> Foam::processorgpuLduInterface::receive
(
    const Pstream::commsTypes commsType,
    const label size
) const
{
    tmp<gpuField<Type>> tresult = tmp<gpuField<Type>>::New(size);

    // Fill the freshly allocated storage
    receive(commsType, tresult.ref());

    return tresult;
}


// Send field f to the neighbouring processor, packed to single precision
// when that is possible and requested.
//
// Packing is used when scalar is not float, Pstream::floatTransfer is set and
// f is non-empty; otherwise the call is forwarded to send().
//
// Packed message layout (nBytes in total):
//     - nm1 floats: every component of all elements except the last, stored
//       as the difference from the matching component of the last element
//     - sizeof(Type) bytes: the last element, unmodified
//
// The packed message is assembled in gpuSendBuf_. With
// Pstream::gpuDirectTransfer it is sent from there; otherwise it is copied to
// the host buffer sendBuf_ first. For nonBlocking communication
// IPstream::read is called on the matching receive buffer before the write.
//
// Any other communications type is a fatal error.
template<class Type>
void Foam::processorgpuLduInterface::compressedSend
(
    const Pstream::commsTypes commsType,
    const gpuList<Type>& f
) const
{
    if (sizeof(scalar) != sizeof(float) && Pstream::floatTransfer && f.size())
    {
        static const label nCmpts = sizeof(Type)/sizeof(scalar);
        const label nm1 = (f.size() - 1)*nCmpts;
        const label nlast = sizeof(Type)/sizeof(float);
        const label nFloats = nm1 + nlast;
        const label nBytes = nFloats*sizeof(float);

        const scalar *sArray = reinterpret_cast<const scalar*>(f.data());
        const scalar *slast = &sArray[nm1];

        // The packed data are written to gpuSendBuf_, so that is the buffer
        // that has to be sized here. (The host sendBuf_ is sized below, on
        // the paths that actually stage through the host.)
        // NOTE(review): assumes resizeBuf only grows the buffer, as its use
        // in send() suggests - confirm.
        resizeBuf(gpuSendBuf_, nBytes);
        float *fArray = reinterpret_cast<float*>(gpuSendBuf_.data());

        // Both iterators are built from a label so that they have the same
        // type irrespective of the configured label width.
        thrust::for_each
        (
            thrust::make_counting_iterator(label(0)),
            thrust::make_counting_iterator(nm1),
            compressFunctor<float,scalar,true>
            (
                fArray,
                sArray,
                slast,
                nCmpts
            )
        );

        // Append the last element at full precision
        HIP_CALL
        (
            hipMemcpy
            (
                fArray + nm1,
                f.data() + (f.size() - 1),
                sizeof(Type),
                hipMemcpyDeviceToDevice
            )
        );

        if
        (
            commsType == Pstream::commsTypes::blocking
         || commsType == Pstream::commsTypes::scheduled
        )
        {
            const char* sendData;
            if (Pstream::gpuDirectTransfer)
            {
                sendData = gpuSendBuf_.data();
            }
            else
            {
                // Stage the packed message through the host
                resizeBuf(sendBuf_, nBytes);
                copyDeviceToHost(sendBuf_.begin(), gpuSendBuf_.data(), nBytes);
                sendData = sendBuf_.begin();
            }

            OPstream::write
            (
                commsType,
                neighbProcNo(),
                sendData,
                nBytes,
                tag(),
                comm()
            );
        }
        else if (commsType == Pstream::commsTypes::nonBlocking)
        {
            const char* sendData;
            char* readData;

            if (Pstream::gpuDirectTransfer)
            {
                resizeBuf(gpuReceiveBuf_, nBytes);
                readData = gpuReceiveBuf_.data();
                sendData = gpuSendBuf_.data();
            }
            else
            {
                resizeBuf(receiveBuf_, nBytes);
                resizeBuf(sendBuf_, nBytes);

                copyDeviceToHost(sendBuf_.begin(), gpuSendBuf_.data(), nBytes);

                sendData = sendBuf_.begin();
                readData = receiveBuf_.begin();
            }

            // Read is issued ahead of the write
            IPstream::read
            (
                commsType,
                neighbProcNo(),
                readData,
                nBytes,
                tag(),
                comm()
            );

            OPstream::write
            (
                commsType,
                neighbProcNo(),
                sendData,
                nBytes,
                tag(),
                comm()
            );
        }
        else
        {
            FatalErrorInFunction
                << "Unsupported communications type " << int(commsType)
                << exit(FatalError);
        }
    }
    else
    {
        this->send(commsType, f);
    }
}


// Receive field f from the neighbouring processor, unpacking a
// single-precision message produced by compressedSend() when applicable.
//
// Unpacking is used when scalar is not float, Pstream::floatTransfer is set
// and f is non-empty; otherwise the call is forwarded to receive().
//
// Expected message layout (nBytes in total):
//     - nm1 floats: components of all elements except the last, stored as
//       the difference from the matching component of the last element
//     - sizeof(Type) bytes: the last element, unmodified
//
// The packed message always ends up in gpuReceiveBuf_ before it is expanded
// into f:
//     blocking / scheduled: read here, directly (Pstream::gpuDirectTransfer)
//         or via the host buffer receiveBuf_
//     nonBlocking: already read by compressedSend(); only the host-to-device
//         copy is done here when direct transfer is off.
//         NOTE(review): presumably the caller has waited for the outstanding
//         requests before calling this - confirm against the callers.
//
// Any other communications type is a fatal error.
template<class Type>
void Foam::processorgpuLduInterface::compressedReceive
(
    const Pstream::commsTypes commsType,
    gpuList<Type>& f
) const
{
    if (sizeof(scalar) != sizeof(float) && Pstream::floatTransfer && f.size())
    {
        static const label nCmpts = sizeof(Type)/sizeof(scalar);
        const label nm1 = (f.size() - 1)*nCmpts;
        const label nlast = sizeof(Type)/sizeof(float);
        const label nFloats = nm1 + nlast;
        const label nBytes = nFloats*sizeof(float);

        if
        (
            commsType == Pstream::commsTypes::blocking
         || commsType == Pstream::commsTypes::scheduled
        )
        {
            // gpuReceiveBuf_ holds the packed message on both paths below,
            // so it must be able to take nBytes before anything is written.
            // NOTE(review): assumes resizeBuf only grows the buffer, as its
            // use in send() suggests - confirm.
            resizeBuf(gpuReceiveBuf_, nBytes);

            char* readData;
            if (Pstream::gpuDirectTransfer)
            {
                readData = gpuReceiveBuf_.data();
            }
            else
            {
                resizeBuf(receiveBuf_, nBytes);
                readData = receiveBuf_.begin();
            }

            IPstream::read
            (
                commsType,
                neighbProcNo(),
                readData,
                nBytes,
                tag(),
                comm()
            );

            if (!Pstream::gpuDirectTransfer)
            {
                // Only nBytes (the packed size) were received; copying
                // f.byteSize() would run past the end of both buffers.
                copyHostToDevice
                (
                    gpuReceiveBuf_.data(),
                    receiveBuf_.data(),
                    nBytes
                );
            }
        }
        else if (commsType == Pstream::commsTypes::nonBlocking)
        {
            if (!Pstream::gpuDirectTransfer)
            {
                // compressedSend() sized only the host receive buffer on
                // this path; size the device staging buffer before filling
                // it. (With direct transfer the data are already in
                // gpuReceiveBuf_ and it must not be touched.)
                resizeBuf(gpuReceiveBuf_, nBytes);

                copyHostToDevice
                (
                    gpuReceiveBuf_.data(),
                    receiveBuf_.data(),
                    nBytes
                );
            }
        }
        else
        {
            FatalErrorInFunction
                << "Unsupported communications type " << int(commsType)
                << exit(FatalError);
        }

        const float *fArray =
            reinterpret_cast<const float*>(gpuReceiveBuf_.data());

        // Restore the last element first: it is the reference that the
        // remaining components are expanded against.
        copyDeviceToDevice
        (
            f.data() + (f.size() - 1),
            fArray + nm1,
            sizeof(Type)
        );

        scalar *sArray = reinterpret_cast<scalar*>(f.data());
        const scalar *slast = &sArray[nm1];

        // Both iterators are built from a label so that they have the same
        // type irrespective of the configured label width.
        thrust::for_each
        (
            thrust::make_counting_iterator(label(0)),
            thrust::make_counting_iterator(nm1),
            compressFunctor<scalar,float,false>
            (
                sArray,
                fArray,
                slast,
                nCmpts
            )
        );
    }
    else
    {
        this->receive<Type>(commsType, f);
    }
}


// Allocate a new field with 'size' elements, fill it via the in-place
// compressedReceive() overload and return it as a tmp.
template<class Type>
Foam::tmp<Foam::gpuField<Type>> Foam::processorgpuLduInterface::compressedReceive
(
    const Pstream::commsTypes commsType,
    const label size
) const
{
    tmp<gpuField<Type>> tresult = tmp<gpuField<Type>>::New(size);

    // Fill the freshly allocated storage
    compressedReceive(commsType, tresult.ref());

    return tresult;
}


// ************************************************************************* //
