/*---------------------------------------------------------------------------*\
  =========                 |
  \\      /  F ield         | OpenFOAM: The Open Source CFD Toolbox
   \\    /   O peration     |
    \\  /    A nd           | www.openfoam.com
     \\/     M anipulation  |
-------------------------------------------------------------------------------
    Copyright (C) 2011-2017 OpenFOAM Foundation
    Copyright (C) 2016-2021 OpenCFD Ltd.
-------------------------------------------------------------------------------
License
    This file is part of OpenFOAM.

    OpenFOAM is free software: you can redistribute it and/or modify it
    under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    OpenFOAM is distributed in the hope that it will be useful, but WITHOUT
    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
    for more details.

    You should have received a copy of the GNU General Public License
    along with OpenFOAM.  If not, see <http://www.gnu.org/licenses/>.

\*---------------------------------------------------------------------------*/

#include "gpufvScalarMatrix.H"
#include "extrapolatedCalculatedFvPatchgpuFields.H"
#include "profiling.H"
#include "PrecisionAdaptor.H"
#include "jumpCyclicFvPatchgpuField.H"
#include "cyclicPolyPatch.H"
#include "cyclicAMIPolyPatch.H"

#include <thrust/copy.h>

// * * * * * * * * * * * * * * * Member Functions  * * * * * * * * * * * * * //

template<>
void Foam::gpufvMatrix<Foam::scalar>::setComponentReference
(
    const label patchi,
    const label facei,
    const direction,
    const scalar value
)
{
    // Only fields that need a reference level are touched, and only on the
    // master rank, so the contribution is applied exactly once.
    if (!psi_.needReference() || !Pstream::master())
    {
        return;
    }

    // Diagonal coefficient of the cell adjacent to the chosen boundary face
    const label celli =
        psi_.mesh().boundary()[patchi].gpuFaceCells().get(facei);
    const scalar diagCoeff = gpuDiag().get(celli);

    auto& patchInternalCoeffs = internalCoeffs_[patchi];
    auto& patchBoundaryCoeffs = boundaryCoeffs_[patchi];

    // Add diagCoeff to the implicit part and diagCoeff*value to the explicit
    // part of this face's coefficients.
    patchInternalCoeffs.set
    (
        facei,
        patchInternalCoeffs.get(facei) + diagCoeff
    );
    patchBoundaryCoeffs.set
    (
        facei,
        patchBoundaryCoeffs.get(facei) + diagCoeff*value
    );
}


template<>
Foam::autoPtr<Foam::gpufvMatrix<Foam::scalar>::fvSolver>
Foam::gpufvMatrix<Foam::scalar>::solver
(
    const dictionary& solverControls
)
{
    // Profiling label "gpufvMatrix::solve.[region::]<field>". The region
    // prefix is only added for non-default regions.
    word regionPrefix;
    if (psi_.mesh().name() != polyMesh::defaultRegion)
    {
        regionPrefix = psi_.mesh().name() + "::";
    }
    addProfiling(solve, "gpufvMatrix::solve." + regionPrefix + psi_.name());

    if (debug)
    {
        Info.masterStream(this->mesh().comm())
            << "fvMatrix<scalar>::solver(const dictionary& solverControls) : "
               "solver for fvMatrix<scalar>"
            << endl;
    }

    // Build the linear solver while the boundary diagonal contribution is
    // folded into the matrix diagonal, then put the original diagonal back.
    const scalargpuField diagBackup(gpuDiag());
    addBoundaryDiag(gpuDiag(), 0);

    lduInterfacegpuFieldPtrsList patchInterfaces =
        psi_.boundaryField().scalarInterfaces();

    autoPtr<gpufvMatrix<scalar>::fvSolver> fvSolverPtr
    (
        new gpufvMatrix<scalar>::fvSolver
        (
            *this,
            gpulduMatrix::solver::New
            (
                psi_.name(),
                *this,
                boundaryCoeffs_,
                internalCoeffs_,
                patchInterfaces,
                solverControls
            )
        )
    );

    gpuDiag() = diagBackup;

    return fvSolverPtr;
}


template<>
Foam::solverPerformance Foam::gpufvMatrix<Foam::scalar>::fvSolver::solve
(
    const dictionary& solverControls
)
{
    // Verbosity: "log" entry of the controls, defaulting to the
    // solverPerformance debug switch.
    const int logLevel =
        solverControls.getOrDefault<int>("log", solverPerformance::debug);

    // The matrix only exposes psi as const; the solution is written in place.
    auto& psiField =
        const_cast<GeometricgpuField<scalar, fvPatchgpuField, gpuvolMesh>&>
        (
            fvMat_.psi()
        );

    // Temporarily add the boundary diagonal to the matrix diagonal (restored
    // after the solve) and build the right-hand side including boundary
    // source contributions.
    scalargpuField diagBackup(fvMat_.gpuDiag());
    fvMat_.addBoundaryDiag(fvMat_.gpuDiag(), 0);

    scalargpuField rhs(fvMat_.source());
    fvMat_.addBoundarySource(rhs, false);

    // Re-read the (possibly updated) controls into the cached solver
    solver_->read(solverControls);

    solverPerformance perf =
        solver_->solve(psiField.primitiveFieldRef(), rhs);

    if (logLevel)
    {
        perf.print(Info.masterStream(fvMat_.mesh().comm()));
    }

    fvMat_.gpuDiag() = diagBackup;

    psiField.correctBoundaryConditions();
    psiField.mesh().hostmesh().setSolverPerformance(psiField.name(), perf);

    return perf;
}


//- Solve the scalar system in one (segregated) linear solve.
//  Without implicit coupling (useImplicit_ false) psi is solved in place.
//  With implicit coupling, the coefficients of all coupled matrices are
//  assembled onto the lduPrimitive assembly mesh. The fields are gathered
//  into one contiguous vector (each field at its cellOffset), solved, and
//  scattered back. The diagonal, and lower/upper when the flux is required,
//  are restored afterwards, and the original lduMesh is reinstated.
template<>
Foam::solverPerformance Foam::gpufvMatrix<Foam::scalar>::solveSegregated
(
    const dictionary& solverControls
)
{
    if (debug)
    {
        Info.masterStream(this->mesh().comm())
            << "fvMatrix<scalar>::solveSegregated"
               "(const dictionary& solverControls) : "
               "solving fvMatrix<scalar>"
            << endl;
    }

    const int logLevel =
        solverControls.getOrDefault<int>
        (
            "log",
            solverPerformance::debug
        );

    scalargpuField saveLower;
    scalargpuField saveUpper;

    if (useImplicit_)
    {
        createOrUpdateLduPrimitiveAssembly();

        if (psi_.mesh().hostmesh().fluxRequired(psi_.name()))
        {
            // Save lower/upper for flux calculation
            if (asymmetric())
            {
                saveLower = gpuLower();
            }
            saveUpper = gpuUpper();
        }

        setLduMesh(*lduMeshPtr());
        transferFvMatrixCoeffs();
        setBounAndInterCoeffs();
        direction cmpt = 0;
        manipulateMatrix(cmpt);
    }

    scalargpuField saveDiag(gpuDiag());
    addBoundaryDiag(gpuDiag(), 0);

    scalargpuField totalSource(source_);
    addBoundarySource(totalSource, false);

    lduInterfacegpuFieldPtrsList interfaces;
    PtrDynList<lduInterfacegpuField> newInterfaces;
    if (!useImplicit_)
    {
        interfaces = this->psi(0).boundaryField().scalarInterfaces();
    }
    else
    {
        setInterfaces(interfaces, newInterfaces);
    }

    tmp<scalargpuField> tpsi;
    if (!useImplicit_)
    {
        // Solve directly into the field's internal values
        tpsi.ref
        (
            const_cast<GeometricgpuField<scalar, fvPatchgpuField, gpuvolMesh>&>
            (
                psi_
            ).primitiveFieldRef()
        );
    }
    else
    {
        // Gather every coupled field into the assembled solution vector,
        // field fieldi occupying [cellOffset, cellOffset + nCells).
        tpsi = tmp<scalargpuField>::New(lduAddr().size(), Zero);
        scalargpuField& psi = tpsi.ref();

        for (label fieldi = 0; fieldi < nMatrices(); fieldi++)
        {
            const label cellOffset = lduMeshPtr()->cellOffsets()[fieldi];
            const auto& psiInternal = this->psi(fieldi).primitiveField();

            // One bulk device copy instead of a get()/set() pair per cell
            thrust::copy
            (
                psiInternal.begin(),
                psiInternal.end(),
                psi.begin() + cellOffset
            );
        }
    }
    scalargpuField& psi = tpsi.ref();

    // Solver call
    solverPerformance solverPerf = gpulduMatrix::solver::New
    (
        this->psi(0).name(),
        *this,
        boundaryCoeffs_,
        internalCoeffs_,
        interfaces,
        solverControls
    )->solve(psi, totalSource);

    if (useImplicit_)
    {
        // Scatter the assembled solution back into each coupled field
        for (label fieldi = 0; fieldi < nMatrices(); fieldi++)
        {
            auto& psiInternal =
                const_cast<GeometricgpuField<scalar, fvPatchgpuField, gpuvolMesh>&>
                (
                    this->psi(fieldi)
                ).primitiveFieldRef();

            const label cellOffset = lduMeshPtr()->cellOffsets()[fieldi];

            // One bulk device copy instead of a get()/set() pair per cell
            thrust::copy
            (
                psi.begin() + cellOffset,
                psi.begin() + cellOffset + psiInternal.size(),
                psiInternal.begin()
            );
        }
    }

    if (logLevel)
    {
        solverPerf.print(Info.masterStream(mesh().comm()));
    }

    gpuDiag() = saveDiag;

    if (useImplicit_)
    {
        if (psi_.mesh().hostmesh().fluxRequired(psi_.name()))
        {
            // Restore lower/upper
            if (asymmetric())
            {
                gpuLower().setSize(saveLower.size());
                gpuLower() = saveLower;
            }

            gpuUpper().setSize(saveUpper.size());
            gpuUpper() = saveUpper;
        }
        // Set the original lduMesh
        setLduMesh(psi_.mesh());
    }

    for (label fieldi = 0; fieldi < nMatrices(); fieldi++)
    {
        auto& localPsi =
            const_cast<GeometricgpuField<scalar, fvPatchgpuField, gpuvolMesh>&>
            (
                this->psi(fieldi)
            );

        localPsi.correctBoundaryConditions();
        localPsi.mesh().hostmesh().setSolverPerformance(localPsi.name(), solverPerf);
    }

    return solverPerf;
}

namespace Foam
{

//- Element-wise functor returning source - get<0>(t)*get<1>(t).
//  residual() uses it with a (boundaryDiag, psi) tuple to form the
//  boundary-corrected source. The functor is stateless, so its call operator
//  is const; Thrust copies functors and may invoke them through const
//  references.
struct fvScalarMatrixResidualFunctor
{
    template<class Tuple>
    __host__ __device__
    scalar operator()(const scalar& source, const Tuple& t) const
    {
        return source - thrust::get<0>(t)*thrust::get<1>(t);
    }
};

}

template<>
Foam::tmp<Foam::scalargpuField> Foam::gpufvMatrix<Foam::scalar>::residual() const
{
    // Work buffer: first holds the boundary diagonal contribution, then is
    // overwritten in place with  source - boundaryDiag*psi.
    scalargpuField correctedSource(psi_.size(), Zero);
    addBoundaryDiag(correctedSource, 0);

    thrust::transform
    (
        source_.begin(),
        source_.end(),
        thrust::make_zip_iterator
        (
            thrust::make_tuple
            (
                correctedSource.begin(),
                psi_.primitiveField().begin()
            )
        ),
        correctedSource.begin(),
        fvScalarMatrixResidualFunctor()
    );

    tmp<scalargpuField> tres(new scalargpuField(psi_.size()));
    scalargpuField& res = tres.ref();

    // Matrix residual against the corrected source (component 0)
    gpulduMatrix::residual
    (
        res,
        psi_.internalField(),
        correctedSource,
        boundaryCoeffs_,
        psi_.boundaryField().scalarInterfaces(),
        0
    );

    addBoundarySource(res);

    return tres;
}


template<>
Foam::tmp<Foam::volScalargpuField> Foam::gpufvMatrix<Foam::scalar>::H() const
{
    // Result "H(<psi>)" with the equation's dimensions per unit volume and
    // extrapolatedCalculated patches.
    tmp<volScalargpuField> tH
    (
        new volScalargpuField
        (
            IOobject
            (
                "H("+psi_.name()+')',
                psi_.instance(),
                psi_.mesh().hostmesh(),
                IOobject::NO_READ,
                IOobject::NO_WRITE
            ),
            psi_.mesh(),
            dimensions_/dimVol,
            extrapolatedCalculatedFvPatchScalargpuField::typeName
        )
    );
    volScalargpuField& Hfield = tH.ref();

    // gpulduMatrix::H(psi) (presumably the off-diagonal product) plus the
    // matrix source, then the boundary source contributions
    Hfield.primitiveFieldRef() = (gpulduMatrix::H(psi_.primitiveField()) + source_);
    addBoundarySource(Hfield.primitiveFieldRef());

    // Per-cell values -> per unit volume
    Hfield.primitiveFieldRef() /= psi_.mesh().V();
    Hfield.correctBoundaryConditions();

    return tH;
}


template<>
Foam::tmp<Foam::volScalargpuField> Foam::gpufvMatrix<Foam::scalar>::H1() const
{
    // Result "H(1)" with dimensions of the equation per unit volume per unit
    // of psi, using extrapolatedCalculated patches.
    tmp<volScalargpuField> tH1field
    (
        new volScalargpuField
        (
            IOobject
            (
                "H(1)",
                psi_.instance(),
                psi_.mesh().hostmesh(),
                IOobject::NO_READ,
                IOobject::NO_WRITE
            ),
            psi_.mesh(),
            dimensions_/(dimVol*psi_.dimensions()),
            extrapolatedCalculatedFvPatchScalargpuField::typeName
        )
    );
    volScalargpuField& H1field = tH1field.ref();

    // Only the gpulduMatrix::H1() contribution is used; no boundary source
    // is added here.
    H1field.primitiveFieldRef() = gpulduMatrix::H1();

    // Per-cell values -> per unit volume
    H1field.primitiveFieldRef() /= psi_.mesh().V();
    H1field.correctBoundaryConditions();

    return tH1field;
}


// ************************************************************************* //
