#include "AINVPreconditioner.H"
#include "AINVPreconditionerF.H"
#include "lduMatrixSolutionCache.H"

namespace Foam
{
    // Type-name and debug-switch definition for this class (debug value 0).
    defineTypeNameAndDebug(AINVPreconditioner, 0);

    // Registration object for the symmetric-matrix preconditioner
    // constructor table (presumably registers on static initialisation;
    // see the table class declared in gpulduMatrix::preconditioner).
    gpulduMatrix::preconditioner::
        addsymMatrixConstructorToTable<AINVPreconditioner>
        addAINVPreconditionerSymMatrixConstructorToTable_;

    // Registration object for the asymmetric-matrix preconditioner
    // constructor table.
    gpulduMatrix::preconditioner::
        addasymMatrixConstructorToTable<AINVPreconditioner>
        addAINVPreconditionerAsymMatrixConstructorToTable_;
}

// Construct from the solver that owns the matrix. The dictionary argument
// is accepted but not used.
//
// rD is built from lduMatrixSolutionCache::first(n) and n, where n is the
// size of the matrix diagonal, and is then filled by transforming every
// diagonal entry with divideOperatorSFFunctor constructed from 1.0
// (presumably yielding 1/diag per entry -- confirm against the functor).
Foam::AINVPreconditioner::AINVPreconditioner
(
    const gpulduMatrix::solver& sol,
    const dictionary&
)
:
    gpulduMatrix::preconditioner(sol),
    rD
    (
        lduMatrixSolutionCache::first(sol.matrix().gpuDiag().size()),
        sol.matrix().gpuDiag().size()
    )
{
    const scalargpuField& matrixDiag = solver_.matrix().gpuDiag();

    // Unary functor applied to each diagonal coefficient.
    const divideOperatorSFFunctor<scalar,scalar,scalar> divideInto(1.0);

    thrust::transform
    (
        matrixDiag.begin(),
        matrixDiag.end(),
        rD.begin(),
        divideInto
    );
}

// Destructor: the body is empty; no explicit clean-up is done here.
Foam::AINVPreconditioner::~AINVPreconditioner()
{}

// Shared implementation behind the preconditioner application.
//
// Template parameter:
//   normalMult - when true the matrix lower/upper coefficients are passed
//                to the functor as (lower, upper); when false they are
//                swapped to (upper, lower). Presumably the latter is the
//                transposed application -- confirm against the callers.
//
// Parameters:
//   w - output field; one value is written per index in [0, r.size())
//   r - input field handed to the functor together with rD
//   d - not referenced in this function body
//
// lduMatrixSolutionCache::favourSpeed selects between two variants: the
// "sorted" one uses gpuOwnerSortAddr() and the *Sort coefficient arrays and
// instantiates AINVPreconditionerFunctor<true,3>; the other uses lowerAddr()
// and the unsorted coefficient arrays with AINVPreconditionerFunctor<false,3>.
// The meaning of the second template argument (3) is defined by the functor
// in AINVPreconditionerF.H.
template<bool normalMult>
void Foam::AINVPreconditioner::preconditionImpl
(
    scalargpuField& w,
    const scalargpuField& r,
    const direction d
) const
{
    const bool useSorted = lduMatrixSolutionCache::favourSpeed;

    // Addressing arrays
    const labelgpuList& lowerIdx =
        useSorted
      ? solver_.matrix().lduAddr().gpuOwnerSortAddr()
      : solver_.matrix().lduAddr().lowerAddr();

    const labelgpuList& upperIdx =
        solver_.matrix().lduAddr().upperAddr();

    const labelgpuList& ownerStart =
        solver_.matrix().lduAddr().ownerStartAddr();

    const labelgpuList& losortStartIdx =
        solver_.matrix().lduAddr().losortStartAddr();

    const labelgpuList& losortIdx =
        solver_.matrix().lduAddr().losortAddr();

    // Coefficient arrays; roles are exchanged when normalMult is false
    const scalargpuField& lowerCoeffs =
        normalMult
      ? (
            useSorted
          ? solver_.matrix().gpuLowerSort()
          : solver_.matrix().gpuLower()
        )
      : (
            useSorted
          ? solver_.matrix().gpuUpperSort()
          : solver_.matrix().gpuUpper()
        );

    const scalargpuField& upperCoeffs =
        normalMult
      ? solver_.matrix().gpuUpper()
      : solver_.matrix().gpuLower();

    // Index range [0, r.size()) that the functor is evaluated over
    const thrust::counting_iterator<int> idxBegin =
        thrust::make_counting_iterator(0);
    const thrust::counting_iterator<int> idxEnd = idxBegin + r.size();

    if (useSorted)
    {
        thrust::transform
        (
            idxBegin,
            idxEnd,
            w.begin(),
            AINVPreconditionerFunctor<true,3>
            (
                r.data(),
                rD.data(),
                lowerCoeffs.data(),
                upperCoeffs.data(),
                lowerIdx.data(),
                upperIdx.data(),
                ownerStart.data(),
                losortStartIdx.data(),
                losortIdx.data()
            )
        );
    }
    else
    {
        thrust::transform
        (
            idxBegin,
            idxEnd,
            w.begin(),
            AINVPreconditionerFunctor<false,3>
            (
                r.data(),
                rD.data(),
                lowerCoeffs.data(),
                upperCoeffs.data(),
                lowerIdx.data(),
                upperIdx.data(),
                ownerStart.data(),
                losortStartIdx.data(),
                losortIdx.data()
            )
        );
    }
}

