/*---------------------------------------------------------------------------*\
  =========                 |
  \\      /  F ield         | OpenFOAM: The Open Source CFD Toolbox
   \\    /   O peration     |
    \\  /    A nd           | www.openfoam.com
     \\/     M anipulation  |
-------------------------------------------------------------------------------
    Copyright (C) 2011-2017 OpenFOAM Foundation
    Copyright (C) 2020 OpenCFD Ltd.
-------------------------------------------------------------------------------
License
    This file is part of OpenFOAM.

    OpenFOAM is free software: you can redistribute it and/or modify it
    under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    OpenFOAM is distributed in the hope that it will be useful, but WITHOUT
    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
    for more details.

    You should have received a copy of the GNU General Public License
    along with OpenFOAM.  If not, see <http://www.gnu.org/licenses/>.

\*---------------------------------------------------------------------------*/

#include "gpuLUscalarMatrix.H"
#include "gpulduMatrix.H"
#include "procgpuLduMatrix.H"
#include "procgpuLduInterface.H"
#include "cyclicgpuLduInterface.H"
#include "SubList.H"

// * * * * * * * * * * * * * * Static Data Members * * * * * * * * * * * * * //

namespace Foam
{
    // Registers the type name and the debug switch (default 0); the switch
    // is what `if (debug)` tests in the gpulduMatrix constructor.
    defineTypeNameAndDebug(gpuLUscalarMatrix, 0);

    // Host-side staging buffers, shared by every gpuLUscalarMatrix object
    // (static members). convert(const gpulduMatrix&, ...) copies the device
    // diagonal, upper/lower coefficients and upper/lower addressing into
    // these before filling the dense matrix.
    // NOTE(review): PageLockedBuffer presumably provides pinned host memory
    // so the hipMemcpyAsync copies can be truly asynchronous — confirm.
    PageLockedBuffer<scalar> gpuLUscalarMatrix::dBuffer;
    PageLockedBuffer<scalar> gpuLUscalarMatrix::lBuffer;
    PageLockedBuffer<scalar> gpuLUscalarMatrix::uBuffer;
    PageLockedBuffer<label> gpuLUscalarMatrix::uAddrBuffer;
    PageLockedBuffer<label> gpuLUscalarMatrix::lAddrBuffer;

    // The two streams convert() uses to issue the device-to-host copies.
    DeviceStream gpuLUscalarMatrix::stream1;
    DeviceStream gpuLUscalarMatrix::stream2;
}


// * * * * * * * * * * * * * * * * Constructors  * * * * * * * * * * * * * * //

Foam::gpuLUscalarMatrix::gpuLUscalarMatrix()
:
    comm_(Pstream::worldComm)
{
    // Default construction: bind to the world communicator only; no
    // coefficients are assembled and no decomposition is performed here.
}


Foam::gpuLUscalarMatrix::gpuLUscalarMatrix(const scalarSquareMatrix& matrix)
:
    scalarSquareMatrix(matrix),
    comm_(Pstream::worldComm),
    pivotIndices_(scalarSquareMatrix::m())
{
    // The base now holds a copy of the input coefficients and the pivot
    // list has one entry per row: factorise the copy in place.
    LUDecompose(*this, pivotIndices_);
}


Foam::gpuLUscalarMatrix::gpuLUscalarMatrix
(
    const gpulduMatrix& ldum,
    const FieldField<gpuField, scalar>& interfaceCoeffs,
    const lduInterfacegpuFieldPtrsList& interfaces
)
:
    comm_(ldum.mesh().comm())
{
    // Assemble the GPU LDU matrix (and its interface coefficients) into a
    // dense square matrix and LU-decompose it.
    //
    // In parallel, every rank packs its local contribution into a
    // procgpuLduMatrix; the non-master ranks send theirs to the master of
    // comm_, which alone assembles the global matrix, decomposes it and
    // holds the pivots. Non-master ranks are left with an empty matrix.

    if (Pstream::parRun())
    {
        // Slot 0 holds the local matrix. On the master, the following slots
        // are filled in the order given by subProcs(comm_).
        PtrList<procgpuLduMatrix> lduMatrices(Pstream::nProcs(comm_));

        label lduMatrixi = 0;

        lduMatrices.set
        (
            lduMatrixi++,
            new procgpuLduMatrix
            (
                ldum,
                interfaceCoeffs,
                interfaces
            )
        );

        if (Pstream::master(comm_))
        {
            for (const int slave : Pstream::subProcs(comm_))
            {
                lduMatrices.set
                (
                    lduMatrixi++,
                    new procgpuLduMatrix
                    (
                        IPstream
                        (
                            Pstream::commsTypes::scheduled,
                            slave,
                            0,          // bufSize
                            Pstream::msgType(),
                            comm_
                        )()
                    )
                );
            }
        }
        else
        {
            OPstream toMaster
            (
                Pstream::commsTypes::scheduled,
                Pstream::masterNo(),
                0,              // bufSize
                Pstream::msgType(),
                comm_
            );

            // Send the local matrix already packed into slot 0 rather than
            // constructing a second, identical procgpuLduMatrix (which
            // presumably repeats the device-to-host copies).
            toMaster<< lduMatrices[0];
        }

        if (Pstream::master(comm_))
        {
            label nCells = 0;
            forAll(lduMatrices, i)
            {
                nCells += lduMatrices[i].size();
            }

            // Zero-initialised global matrix; convert() only writes the
            // non-zero entries.
            scalarSquareMatrix mat(nCells, Zero);
            transfer(mat);
            convert(lduMatrices);
        }
    }
    else
    {
        const label nCells = ldum.lduAddr().size();
        scalarSquareMatrix mat(nCells, Zero);
        transfer(mat);
        convert(ldum, interfaceCoeffs, interfaces);
    }

    if (Pstream::master(comm_))
    {
        if (debug)
        {
            // Dump each row: diagonal, then the off-diagonal connections
            // above and below it whose magnitude exceeds SMALL.
            const label numRows = m();
            const label numCols = n();

            Pout<< "gpuLUscalarMatrix : size:" << numRows << endl;
            for (label rowi = 0; rowi < numRows; ++rowi)
            {
                const scalar* row = operator[](rowi);

                Pout<< "cell:" << rowi << " diagCoeff:" << row[rowi] << endl;

                Pout<< "    connects to upper cells :";
                for (label coli = rowi+1; coli < numCols; ++coli)
                {
                    if (mag(row[coli]) > SMALL)
                    {
                        Pout<< ' ' << coli << " (coeff:" << row[coli] << ')';
                    }
                }
                Pout<< endl;
                Pout<< "    connects to lower cells :";
                for (label coli = 0; coli < rowi; ++coli)
                {
                    if (mag(row[coli]) > SMALL)
                    {
                        Pout<< ' ' << coli << " (coeff:" << row[coli] << ')';
                    }
                }
                Pout<< nl;
            }
            Pout<< nl;
        }

        pivotIndices_.setSize(m());
        LUDecompose(*this, pivotIndices_);
    }
}


// * * * * * * * * * * * * * * * Member Functions  * * * * * * * * * * * * * //

// Fill this dense matrix from a single (non-distributed) GPU LDU matrix
// plus its coupled interfaces. The callers in this file size the matrix to
// nCells x nCells and zero it beforehand; only non-zero entries are written.
//
// Coefficients and addressing live on the device, so they are first copied
// into the static host staging buffers:
//   - stream1 carries the diagonal and both addressing arrays,
//   - stream2 carries the upper and lower coefficients.
// Only stream1 is synchronised before the diagonal loop, so that loop can
// overlap with the stream2 copies; stream2 is synchronised before the face
// loop, which needs everything.
//
// NOTE(review): the hipMemcpyAsync return codes are ignored, so a failed
// copy would go unnoticed.
// NOTE(review): the staging buffers are static and shared by all objects;
// this presumably relies on convert() never running concurrently — confirm.
void Foam::gpuLUscalarMatrix::convert
(
    const gpulduMatrix& ldum,
    const FieldField<gpuField, scalar>& interfaceCoeffs,
    const lduInterfacegpuFieldPtrsList& interfaces
)
{
    //const label* __restrict__ uPtr = ldum.lduAddr().upperAddr().begin();
    //const label* __restrict__ lPtr = ldum.lduAddr().lowerAddr().begin();

    // Device-side sources
    const scalargpuField& diag = ldum.gpuDiag();
    const scalargpuField& upper = ldum.gpuUpper();
    const scalargpuField& lower = ldum.gpuLower();
    const labelgpuList& uAddr = ldum.lduAddr().upperAddr();
    const labelgpuList& lAddr = ldum.lduAddr().lowerAddr();

    const label nCells = diag.size();
    const label nFaces = upper.size();

    // Host-side destinations, taken from the static staging buffers and
    // requested at the needed sizes
    Field<scalar>& diagPtr = dBuffer.buffer(nCells);
    Field<scalar>& upperPtr = uBuffer.buffer(nFaces);
    Field<scalar>& lowerPtr = lBuffer.buffer(nFaces);
    List<label>& uPtr = uAddrBuffer.buffer(nFaces);
    List<label>& lPtr = lAddrBuffer.buffer(nFaces);
    
    hipMemcpyAsync(diagPtr.data(), diag.data(), diag.byteSize(), hipMemcpyDeviceToHost, stream1());
    hipMemcpyAsync(uPtr.data(), uAddr.data(), uAddr.byteSize(), hipMemcpyDeviceToHost, stream1());
    hipMemcpyAsync(upperPtr.data(), upper.data(), upper.byteSize(), hipMemcpyDeviceToHost, stream2());
    hipMemcpyAsync(lowerPtr.data(), lower.data(), lower.byteSize(), hipMemcpyDeviceToHost, stream2());
    hipMemcpyAsync(lPtr.data(), lAddr.data(), lAddr.byteSize(), hipMemcpyDeviceToHost, stream1());

    // Diagonal and addressing are now on the host
    stream1.synchronize();

    for (label cell=0; cell<nCells; cell++)
    {
        operator[](cell)[cell] = diagPtr[cell];
    }

    // Upper/lower coefficients are now on the host as well
    stream2.synchronize();

    // Face loop: the lower coefficient goes to (upper cell, lower cell) and
    // the upper coefficient to (lower cell, upper cell)
    for (label face=0; face<nFaces; face++)
    {
        label uCell = uPtr[face];
        label lCell = lPtr[face];

        operator[](uCell)[lCell] = lowerPtr[face];
        operator[](lCell)[uCell] = upperPtr[face];
    }

    forAll(interfaces, inti)
    {
        if (interfaces.set(inti))
        {
            const gpulduInterface& interface = interfaces[inti].interface();

            // Assume any interfaces are cyclic ones

            // Copy this patch's face cells from the device to a host list
            //const label* __restrict__ lPtr = interface.faceCells().begin();
            const labelgpuList& lAddrPtr = interface.gpuFaceCells();
            labelList lPtr(lAddrPtr.size());//(lAddrPtr); 
            thrust::copy(lAddrPtr.begin(),lAddrPtr.end(),lPtr.begin());

            // Look up the neighbouring cyclic patch and copy its face cells
            // to the host too.
            // NOTE(review): refCast presumably raises a fatal error if the
            // interface is not cyclic — confirm.
            const cyclicgpuLduInterface& cycInterface =
                refCast<const cyclicgpuLduInterface>(interface);
            label nbrInt = cycInterface.neighbPatchID();
            //const label* __restrict__ uPtr =
            //    interfaces[nbrInt].interface().faceCells().begin();
            const labelgpuList& uAddrPtr = interfaces[nbrInt].interface().gpuFaceCells();
            labelList uPtr(uAddrPtr.size());//(uAddrPtr); 
            thrust::copy(uAddrPtr.begin(),uAddrPtr.end(),uPtr.begin());

            // Host copy of the neighbour patch's interface coefficients
            const scalargpuField& coeffs = interfaceCoeffs[nbrInt];
            const scalarField nbrUpperLowerPtr(coeffs);

            label inFaces = interface.gpuFaceCells().size();

            // Subtract the neighbour coefficient at (this-side cell,
            // neighbour-side cell) for every face of the patch
            for (label face=0; face<inFaces; face++)
            {
                label uCell = lPtr[face];
                label lCell = uPtr[face];

                operator[](uCell)[lCell] -= nbrUpperLowerPtr[face];
            }
        }
    }
}


void Foam::gpuLUscalarMatrix::convert
(
    const PtrList<procgpuLduMatrix>& lduMatrices
)
{
    // Fill this dense matrix (already sized and zeroed by the caller) from
    // the per-processor LDU matrices gathered on the master.
    //
    // The processor matrices are stacked along the diagonal; procOffsets_[i]
    // is the first global row/column of processor i's block.
    // Interface couplings are then added as off-diagonal entries.
    // NOTE(review): indexing lduMatrices by neighbProcNo_ assumes that slot i
    // holds the matrix of rank i (as the gpulduMatrix constructor assembles
    // it) — confirm for non-default communicators.

    procOffsets_.setSize(lduMatrices.size() + 1);
    procOffsets_[0] = 0;

    forAll(lduMatrices, ldumi)
    {
        procOffsets_[ldumi+1] = procOffsets_[ldumi] + lduMatrices[ldumi].size();
    }

    forAll(lduMatrices, ldumi)
    {
        const procgpuLduMatrix& lduMatrixi = lduMatrices[ldumi];
        label offset = procOffsets_[ldumi];

        const label* __restrict__ uPtr = lduMatrixi.upperAddr_.begin();
        const label* __restrict__ lPtr = lduMatrixi.lowerAddr_.begin();

        const scalar* __restrict__ diagPtr = lduMatrixi.diag_.begin();
        const scalar* __restrict__ upperPtr = lduMatrixi.upper_.begin();
        const scalar* __restrict__ lowerPtr = lduMatrixi.lower_.begin();

        const label nCells = lduMatrixi.size();
        const label nFaces = lduMatrixi.upper_.size();

        // Local diagonal, shifted into this processor's block
        for (label cell=0; cell<nCells; cell++)
        {
            label globalCell = cell + offset;
            operator[](globalCell)[globalCell] = diagPtr[cell];
        }

        // Local faces: lower coefficient at (upper, lower), upper
        // coefficient at (lower, upper), both shifted into the block
        for (label face=0; face<nFaces; face++)
        {
            label uCell = uPtr[face] + offset;
            label lCell = lPtr[face] + offset;

            operator[](uCell)[lCell] = lowerPtr[face];
            operator[](lCell)[uCell] = upperPtr[face];
        }

        const PtrList<procgpuLduInterface>& interfaces =
            lduMatrixi.interfaces_;

        forAll(interfaces, inti)
        {
            const procgpuLduInterface& interface = interfaces[inti];

            if (interface.myProcNo_ == interface.neighbProcNo_)
            {
                // Interface coupling the processor to itself: faceCells_ and
                // coeffs_ hold the two sides back to back (first half and
                // second half), so both entries are inserted here.
                const label* __restrict__ ulPtr = interface.faceCells_.begin();

                const scalar* __restrict__ upperLowerPtr =
                    interface.coeffs_.begin();

                label inFaces = interface.faceCells_.size()/2;

                for (label face=0; face<inFaces; face++)
                {
                    label uCell = ulPtr[face] + offset;
                    label lCell = ulPtr[face + inFaces] + offset;

                    operator[](uCell)[lCell] -= upperLowerPtr[face + inFaces];
                    operator[](lCell)[uCell] -= upperLowerPtr[face];
                }
            }
            else if (interface.myProcNo_ < interface.neighbProcNo_)
            {
                // Interface to neighbour proc. Find on neighbour proc the
                // corresponding interface. The problem is that there can
                // be multiple interfaces between two processors (from
                // processorCyclics) so also compare the communication tag.
                // Only the lower-numbered side handles the pair, inserting
                // both off-diagonal entries, so each pair is done once.

                const PtrList<procgpuLduInterface>& neiInterfaces =
                    lduMatrices[interface.neighbProcNo_].interfaces_;

                label neiInterfacei = -1;

                forAll(neiInterfaces, ninti)
                {
                    if
                    (
                        (
                            neiInterfaces[ninti].neighbProcNo_
                         == interface.myProcNo_
                        )
                     && (neiInterfaces[ninti].tag_ ==  interface.tag_)
                    )
                    {
                        neiInterfacei = ninti;
                        break;
                    }
                }

                if (neiInterfacei == -1)
                {
                    FatalErrorInFunction
                        << "Cannot find the interface on processor "
                        << interface.neighbProcNo_
                        << " matching the interface from processor "
                        << interface.myProcNo_
                        << " with tag " << interface.tag_
                        << exit(FatalError);
                }

                const procgpuLduInterface& neiInterface =
                    neiInterfaces[neiInterfacei];

                const label* __restrict__ uPtr = interface.faceCells_.begin();
                const label* __restrict__ lPtr =
                    neiInterface.faceCells_.begin();

                const scalar* __restrict__ upperPtr = interface.coeffs_.begin();
                const scalar* __restrict__ lowerPtr =
                    neiInterface.coeffs_.begin();

                label inFaces = interface.faceCells_.size();
                label neiOffset = procOffsets_[interface.neighbProcNo_];

                for (label face=0; face<inFaces; face++)
                {
                    label uCell = uPtr[face] + offset;
                    label lCell = lPtr[face] + neiOffset;

                    operator[](uCell)[lCell] -= lowerPtr[face];
                    operator[](lCell)[uCell] -= upperPtr[face];
                }
            }
        }
    }
}


void Foam::gpuLUscalarMatrix::printDiagonalDominance() const
{
    for (label i=0; i<m(); i++)
    {
        scalar sum = 0.0;
        for (label j=0; j<m(); j++)
        {
            if (i != j)
            {
                sum += operator[](i)[j];
            }
        }
        Info<< mag(sum)/mag(operator[](i)[i]) << endl;
    }
}


void Foam::gpuLUscalarMatrix::decompose(const scalarSquareMatrix& M)
{
    scalarSquareMatrix::operator=(M);
    pivotIndices_.setSize(m());
    LUDecompose(*this, pivotIndices_);
}


void Foam::gpuLUscalarMatrix::inv(scalarSquareMatrix& M) const
{
    // Build the inverse column by column: back-substitute the unit vector
    // e_c through the stored LU factors, and the solution is column c of
    // the inverse. M is written element-wise and is expected to already be
    // m() x m().
    const label nRows = m();
    scalarField column(nRows);

    for (label coli = 0; coli < nRows; ++coli)
    {
        column = Zero;
        column[coli] = 1;
        LUBacksubstitute(*this, pivotIndices_, column);

        for (label rowi = 0; rowi < nRows; ++rowi)
        {
            M(rowi, coli) = column[rowi];
        }
    }
}


// ************************************************************************* //
