/*---------------------------------------------------------------------------*\
  =========                 |
  \\      /  F ield         | OpenFOAM: The Open Source CFD Toolbox
   \\    /   O peration     |
    \\  /    A nd           | www.openfoam.com
     \\/     M anipulation  |
-------------------------------------------------------------------------------
    Copyright (C) 2011-2017 OpenFOAM Foundation
    Copyright (C) 2016-2021 OpenCFD Ltd.
-------------------------------------------------------------------------------
License
    This file is part of OpenFOAM.

    OpenFOAM is free software: you can redistribute it and/or modify it
    under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    OpenFOAM is distributed in the hope that it will be useful, but WITHOUT
    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
    for more details.

    You should have received a copy of the GNU General Public License
    along with OpenFOAM.  If not, see <http://www.gnu.org/licenses/>.

\*---------------------------------------------------------------------------*/

#include "gpulduMatrix.H"
#include "GdiagonalSolver.H"
#include "PrecisionAdaptor.H"

#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/reduce.h>

// * * * * * * * * * * * * * * Static Data Members * * * * * * * * * * * * * //

// Run-time selection tables from which gpulduMatrix::solver::New() picks a
// solver by the "solver" keyword: symMatrix is consulted for symmetric
// matrices, asymMatrix for asymmetric matrices.
namespace Foam
{
    defineRunTimeSelectionTable(gpulduMatrix::solver, symMatrix);
    defineRunTimeSelectionTable(gpulduMatrix::solver, asymMatrix);
}


// * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * //

// Select and construct a solver for the given matrix.
//
// - Diagonal matrices always get a GdiagonalSolver.
// - Symmetric matrices are looked up by name in the symMatrix table.
// - Asymmetric matrices are looked up by name in the asymMatrix table.
// An unknown solver name, or a matrix that is none of the above, raises a
// FatalIOError against solverControls.
Foam::autoPtr<Foam::gpulduMatrix::solver> Foam::gpulduMatrix::solver::New
(
    const word& fieldName,
    const gpulduMatrix& matrix,
    const FieldField<gpuField, scalar>& interfaceBouCoeffs,
    const FieldField<gpuField, scalar>& interfaceIntCoeffs,
    const lduInterfacegpuFieldPtrsList& interfaces,
    const dictionary& solverControls
)
{
    // Read unconditionally, before the matrix type is inspected
    const word solverName(solverControls.get<word>("solver"));

    // Diagonal-only matrix: no table lookup needed
    if (matrix.diagonal())
    {
        return autoPtr<gpulduMatrix::solver>
        (
            new GdiagonalSolver
            (
                fieldName,
                matrix,
                interfaceBouCoeffs,
                interfaceIntCoeffs,
                interfaces,
                solverControls
            )
        );
    }

    // Symmetric matrix: select from the symMatrix table
    if (matrix.symmetric())
    {
        auto* symCtor = symMatrixConstructorTable(solverName);

        if (!symCtor)
        {
            FatalIOErrorInLookup
            (
                solverControls,
                "symmetric matrix solver",
                solverName,
                *symMatrixConstructorTablePtr_
            ) << exit(FatalIOError);
        }

        return autoPtr<gpulduMatrix::solver>
        (
            symCtor
            (
                fieldName,
                matrix,
                interfaceBouCoeffs,
                interfaceIntCoeffs,
                interfaces,
                solverControls
            )
        );
    }

    // Asymmetric matrix: select from the asymMatrix table
    if (matrix.asymmetric())
    {
        auto* asymCtor = asymMatrixConstructorTable(solverName);

        if (!asymCtor)
        {
            FatalIOErrorInLookup
            (
                solverControls,
                "asymmetric matrix solver",
                solverName,
                *asymMatrixConstructorTablePtr_
            ) << exit(FatalIOError);
        }

        return autoPtr<gpulduMatrix::solver>
        (
            asymCtor
            (
                fieldName,
                matrix,
                interfaceBouCoeffs,
                interfaceIntCoeffs,
                interfaces,
                solverControls
            )
        );
    }

    // None of the above: the matrix has no usable coefficients
    FatalIOErrorInFunction(solverControls)
        << "cannot solve incomplete matrix, "
        "no diagonal or off-diagonal coefficient"
        << exit(FatalIOError);

    return nullptr;
}


// * * * * * * * * * * * * * * * * Constructors  * * * * * * * * * * * * * * //

// Construct from the field name, the matrix, the interface boundary/internal
// coefficients, the interfaces and the solver-control dictionary.
//
// Each argument is stored in the corresponding member (whether as a copy or
// a reference is decided by the member declarations in the header).
// profiling_ is initialised with the label "gpulduMatrix::solver.<fieldName>".
// The convergence controls (log, minIter, maxIter, tolerance, relTol) are
// then read from the stored dictionary by readControls().
//
// NOTE: members are initialised in their header declaration order, not in
// the order they are listed below.
Foam::gpulduMatrix::solver::solver
(
    const word& fieldName,
    const gpulduMatrix& matrix,
    const FieldField<gpuField, scalar>& interfaceBouCoeffs,
    const FieldField<gpuField, scalar>& interfaceIntCoeffs,
    const lduInterfacegpuFieldPtrsList& interfaces,
    const dictionary& solverControls
)
:
    fieldName_(fieldName),
    matrix_(matrix),
    interfaceBouCoeffs_(interfaceBouCoeffs),
    interfaceIntCoeffs_(interfaceIntCoeffs),
    interfaces_(interfaces),
    controlDict_(solverControls),
    profiling_("gpulduMatrix::solver." + fieldName)
{
    // Populate log_, minIter_, maxIter_, tolerance_ and relTol_
    readControls();
}


// * * * * * * * * * * * * * * * Member Functions  * * * * * * * * * * * * * //

void Foam::gpulduMatrix::solver::readControls()
{
    log_ = controlDict_.getOrDefault<int>("log", 1);
    minIter_ = controlDict_.getOrDefault<label>("minIter", 0);
    maxIter_ = controlDict_.getOrDefault<label>("maxIter", defaultMaxIter_);
    tolerance_ = controlDict_.getOrDefault<scalar>("tolerance", 1e-6);
    relTol_ = controlDict_.getOrDefault<scalar>("relTol", 0);
}


void Foam::gpulduMatrix::solver::read(const dictionary& solverControls)
{
    controlDict_ = solverControls;
    readControls();
}


// Solve for component cmpt of psi with the given source by forwarding the
// arguments unchanged to solve().
//
// NOTE(review): a PrecisionAdaptor-based scalar -> solveScalar conversion
// used to sit here but was disabled. psi and source are passed through as
// scalargpuField. Confirm this is intended for builds where solveScalar
// differs from scalar.
Foam::solverPerformance Foam::gpulduMatrix::solver::scalarSolve
(
    scalargpuField& psi,
    const scalargpuField& source,
    const direction cmpt
) const
{
    return solve(psi, source, cmpt);
}

namespace Foam 
{

struct normFactorFunctor: public thrust::unary_function<label, double> {
     const scalar * Apsi;
     const scalar * source;
     const scalar * tmpField;
     const scalar average;

     normFactorFunctor (
          const scalar * _Apsi,
          const scalar * _source,
          const scalar * _tmpField,
          const scalar _average
    ): Apsi(_Apsi), source(_source), tmpField(_tmpField), average(_average) {}

    __host__ __device__
    scalar operator()(const label id)
    {
        scalar tmpVal = average * tmpField[id];
        return mag(Apsi[id] - tmpVal) + mag(source[id] - tmpVal);
    }
};

}

// Normalisation factor for the solver residual:
//
//     sum_i ( |Apsi_i - avg*(A.1)_i| + |source_i - avg*(A.1)_i| ) + small_
//
// where avg is the global average of psi and A.1 is the result of
// matrix_.sumA(), written into the scratch field tmpField. tmpField is
// overwritten with that result and is not scaled by avg. The local sum is
// computed with thrust::reduce and then summed across processors on the
// matrix mesh's communicator.
//
// At convergence the simpler 2*gSumMag(source) + small_ is equivalent.
Foam::solveScalarField::cmptType Foam::gpulduMatrix::solver::normFactor
(
    const scalargpuField& psi,
    const scalargpuField& source,
    const scalargpuField& Apsi,
    scalargpuField& tmpField
) const
{
    // --- Calculate A dot reference value of psi
    matrix_.sumA(tmpField, interfaceBouCoeffs_, interfaces_);

    // Use one communicator for both the global average and the global sum
    const label comm = matrix_.mesh().comm();

    const scalar average = gAverage(psi, comm);

    const normFactorFunctor cellContribution
    (
        Apsi.data(), source.data(), tmpField.data(), average
    );

    // Both ends of the range must share one iterator type. A bare literal 0
    // would deduce counting_iterator<int> while psi.size() gives a label,
    // which breaks thrust::reduce's template deduction for 64-bit labels.
    const auto first = thrust::make_transform_iterator
    (
        thrust::make_counting_iterator<label>(0),
        cellContribution
    );
    const auto last = thrust::make_transform_iterator
    (
        thrust::make_counting_iterator<label>(psi.size()),
        cellContribution
    );

    scalar factor = thrust::reduce(first, last);

    reduce(factor, sumOp<scalar>(), Pstream::msgType(), comm);

    return factor + solverPerformance::small_;
}

// ************************************************************************* //
