/*---------------------------------------------------------------------------*\
  =========                 |
  \\      /  F ield         | OpenFOAM: The Open Source CFD Toolbox
   \\    /   O peration     |
    \\  /    A nd           | www.openfoam.com
     \\/     M anipulation  |
-------------------------------------------------------------------------------
    Copyright (C) 2011-2014 OpenFOAM Foundation
    Copyright (C) 2016-2021 OpenCFD Ltd.
-------------------------------------------------------------------------------
License
    This file is part of OpenFOAM.

    OpenFOAM is free software: you can redistribute it and/or modify it
    under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    OpenFOAM is distributed in the hope that it will be useful, but WITHOUT
    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
    for more details.

    You should have received a copy of the GNU General Public License
    along with OpenFOAM.  If not, see <http://www.gnu.org/licenses/>.

\*---------------------------------------------------------------------------*/

#include "GsmoothSolver.H"
#include "profiling.H"
#include "PrecisionAdaptor.H"

// * * * * * * * * * * * * * * Static Data Members * * * * * * * * * * * * * //

namespace Foam
{
    // Type-name/debug definition macro for GsmoothSolver; the second
    // argument (0) is the value passed for the debug switch.
    // NOTE(review): exact expansion lives in the project's macro headers,
    // not in this file - confirm there if the details matter.
    defineTypeNameAndDebug(GsmoothSolver, 0);

    // File-scope static objects of the gpulduMatrix::solver
    // "add...ConstructorToTable" types, instantiated for GsmoothSolver.
    // Presumably their construction is what makes this solver selectable
    // at run time for symmetric matrices (first object) and asymmetric
    // matrices (second object) - verify against gpulduMatrix::solver.
    gpulduMatrix::solver::addsymMatrixConstructorToTable<GsmoothSolver>
        addGsmoothSolverSymMatrixConstructorToTable_;

    gpulduMatrix::solver::addasymMatrixConstructorToTable<GsmoothSolver>
        addGsmoothSolverAsymMatrixConstructorToTable_;
}


// * * * * * * * * * * * * * * * * Constructors  * * * * * * * * * * * * * * //

//- Construct from the field name, the matrix, the interface coefficient
//  sets, the interface list and the solver-controls dictionary.
//  Every argument is forwarded unchanged to the gpulduMatrix::solver base;
//  the only work done here is reading the controls.
Foam::GsmoothSolver::GsmoothSolver
(
    const word& fieldName,
    const gpulduMatrix& matrix,
    const FieldField<gpuField, scalar>& interfaceBouCoeffs,
    const FieldField<gpuField, scalar>& interfaceIntCoeffs,
    const lduInterfacegpuFieldPtrsList& interfaces,
    const dictionary& solverControls
)
:
    gpulduMatrix::solver
    (
        fieldName, matrix,
        interfaceBouCoeffs, interfaceIntCoeffs, interfaces,
        solverControls
    )
{
    // Base is fully constructed at this point: read the controls,
    // which for this class also sets the sweep count
    this->readControls();
}


// * * * * * * * * * * * * * * * Member Functions  * * * * * * * * * * * * * //

void Foam::GsmoothSolver::readControls()
{
    gpulduMatrix::solver::readControls();
    nSweeps_ = controlDict_.getOrDefault<label>("nSweeps", 1);
}

//- Smooth psi against source for component cmpt and return the
//  performance record.
//
//  Two modes, selected by the sign of nSweeps_:
//  - nSweeps_ >= 0 : compute a normalised initial residual, then apply
//    nSweeps_ smoother sweeps per pass until the convergence check
//    succeeds or maxIter_ is reached (minIter_ forces further passes).
//  - nSweeps_ <  0 : apply exactly -nSweeps_ sweeps; no residual is
//    evaluated and only the iteration count is recorded.
Foam::solverPerformance Foam::GsmoothSolver::solve
(
    scalargpuField& psi,
    const scalargpuField& source,
    const direction cmpt
) const
{
    // Performance record handed back to the caller
    solverPerformance solverPerf(typeName, fieldName_);

    if (nSweeps_ >= 0)
    {
        // Residual-controlled smoothing

        // Communicator used for the global sums and the master stream
        const label meshComm = matrix().mesh().comm();

        solveScalar normaliser = 0;
        scalargpuField res(psi.size());

        {
            // Work fields are scoped so they are released before the
            // smoothing loop starts
            scalargpuField APsi(psi.size());
            scalargpuField work(psi.size());

            // APsi = A.psi
            matrix_.Amul(APsi, psi, interfaceBouCoeffs_, interfaces_, cmpt);

            // Residual normalisation factor
            normaliser = this->normFactor(psi, source, APsi, work);

            // res = source - APsi, evaluated element-wise
            thrust::transform
            (
                source.begin(),
                source.end(),
                APsi.begin(),
                res.begin(),
                subtractOperatorFunctor<scalar,scalar,scalar>()
            );

            solverPerf.initialResidual() = gSumMag(res, meshComm)/normaliser;
            solverPerf.finalResidual() = solverPerf.initialResidual();
        }

        const bool verbose = (log_ >= 2) || (gpulduMatrix::debug >= 2);

        if (verbose)
        {
            Info.masterStream(meshComm)
                << "   Normalisation factor = " << normaliser << endl;
        }

        // The convergence check is skipped (short-circuit) whenever a
        // minimum iteration count is requested
        const bool needSmoothing =
        (
            minIter_ > 0
         || !solverPerf.checkConvergence(tolerance_, relTol_, log_)
        );

        if (needSmoothing)
        {
            addProfiling(solve, "gpulduMatrix::smoother." + fieldName_);

            autoPtr<gpulduMatrix::smoother> smootherPtr
            (
                gpulduMatrix::smoother::New
                (
                    fieldName_,
                    matrix_,
                    interfaceBouCoeffs_,
                    interfaceIntCoeffs_,
                    interfaces_,
                    controlDict_
                )
            );

            // NOTE(review): with nSweeps_ == 0 the iteration counter is
            // not advanced below, so the maxIter_ bound cannot end the
            // loop - confirm a zero sweep count is never supplied.
            bool keepSmoothing = false;

            do
            {
                smootherPtr->smooth(psi, source, cmpt, nSweeps_);

                res = matrix_.residual
                (
                    psi,
                    source,
                    interfaceBouCoeffs_,
                    interfaces_,
                    cmpt
                );

                // Updated residual for the convergence check
                solverPerf.finalResidual() =
                    gSumMag(res, meshComm)/normaliser;

                // Each pass accounts for nSweeps_ iterations
                solverPerf.nIterations() += nSweeps_;

                // Continue while below maxIter_ and unconverged (the
                // check is not evaluated once maxIter_ is reached), or
                // while minIter_ has not yet been met
                keepSmoothing =
                (
                    (
                        solverPerf.nIterations() < maxIter_
                     && !solverPerf.checkConvergence(tolerance_, relTol_, log_)
                    )
                 || solverPerf.nIterations() < minIter_
                );
            } while (keepSmoothing);
        }
    }
    else
    {
        // Fixed number of sweeps: |nSweeps_|, no residual evaluation

        addProfiling(solve, "gpulduMatrix::smoother." + fieldName_);

        autoPtr<gpulduMatrix::smoother> smootherPtr
        (
            gpulduMatrix::smoother::New
            (
                fieldName_,
                matrix_,
                interfaceBouCoeffs_,
                interfaceIntCoeffs_,
                interfaces_,
                controlDict_
            )
        );

        smootherPtr->smooth(psi, source, cmpt, -nSweeps_);

        // nSweeps_ is negative here, so this adds the sweeps performed
        solverPerf.nIterations() -= nSweeps_;
    }

    return solverPerf;
}

// ************************************************************************* //
