/*---------------------------------------------------------------------------*\
  =========                 |
  \\      /  F ield         | OpenFOAM: The Open Source CFD Toolbox
   \\    /   O peration     |
    \\  /    A nd           | www.openfoam.com
     \\/     M anipulation  |
-------------------------------------------------------------------------------
    Copyright (C) 2011-2016 OpenFOAM Foundation
    Copyright (C) 2016 OpenCFD Ltd.
-------------------------------------------------------------------------------
License
    This file is part of OpenFOAM.

    OpenFOAM is free software: you can redistribute it and/or modify it
    under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    OpenFOAM is distributed in the hope that it will be useful, but WITHOUT
    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
    for more details.

    You should have received a copy of the GNU General Public License
    along with OpenFOAM.  If not, see <http://www.gnu.org/licenses/>.

\*---------------------------------------------------------------------------*/

#include "gpulduAddressing.H"
#include "demandDrivenData.H"
#include "scalarField.H"

#include <thrust/functional.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/scan.h>
#include <thrust/unique.h>

// * * * * * * * * * * * * * Private Member Functions  * * * * * * * * * * * //

void Foam::gpulduAddressing::calcLosort() const
{
    if (losortAddr_)
    {
        FatalErrorInFunction
            << "losort already calculated"
            << abort(FatalError);
    }

    const labelgpuList& nbr = upperAddr();

    losortAddr_ = new labelgpuList(nbr.size());

    labelgpuList& lst = *losortAddr_;

    labelgpuList nbrTmp(nbr);

    thrust::counting_iterator<label> first(0);
    thrust::copy
    (
        first,
        first+nbr.size(),
        lst.begin()
    );

    thrust::stable_sort_by_key
    (
        nbrTmp.begin(),
        nbrTmp.end(),
        lst.begin()
    );
}


void Foam::gpulduAddressing::calcOwnerStart() const
{
    if (ownerStartAddr_)
    {
        FatalErrorInFunction
            << "owner start already calculated"
            << abort(FatalError);
    }

    const labelgpuList& own = lowerAddr();

    ownerStartAddr_ = new labelgpuList(size() + 1, own.size());

    labelgpuList& ownStart = *ownerStartAddr_;

    labelgpuList ones(own.size()+size(),1);
    labelgpuList tmpSum(size());

    labelgpuList ownSort(own.size()+size());


    thrust::copy
    (
        own.begin(),
        own.end(),
        ownSort.begin()
    );

    thrust::copy
    (
        thrust::make_counting_iterator(0),
        thrust::make_counting_iterator(0)+size(),
        ownSort.begin()+own.size()
    );

    thrust::fill
    (
        ones.begin()+own.size(),
        ones.end(),
        0
    );

    thrust::stable_sort_by_key
    (
        ownSort.begin(),
        ownSort.end(),
        ones.begin()
    );


    thrust::reduce_by_key
    (
        ownSort.begin(),
        ownSort.end(),
        ones.begin(),
        thrust::make_discard_iterator(),
        tmpSum.begin()
    );

    thrust::exclusive_scan
    (
        tmpSum.begin(),
        tmpSum.end(),
        ownStart.begin()
    );
}


void Foam::gpulduAddressing::calcLosortStart() const
{
    if (losortStartAddr_)
    {
        FatalErrorInFunction
            << "losort start already calculated"
            << abort(FatalError);
    }


    const labelgpuList& nbr = upperAddr();

    const labelgpuList& lsrt = losortAddr();

    losortStartAddr_ = new labelgpuList(size() + 1, nbr.size());

    labelgpuList& lsrtStart = *losortStartAddr_;

    labelgpuList ones(nbr.size()+size(),1);
    labelgpuList tmpSum(size());

    labelgpuList nbrSort(nbr.size()+size());

    thrust::copy
    (
        thrust::make_permutation_iterator
        (
            nbr.begin(),
            lsrt.begin()
        ),
        thrust::make_permutation_iterator
        (
            nbr.begin(),
            lsrt.end()
        ),
        nbrSort.begin()
    );

    thrust::copy
    (
        thrust::make_counting_iterator(0),
        thrust::make_counting_iterator(0)+size(),
        nbrSort.begin()+nbr.size()
    );

    thrust::fill
    (
        ones.begin()+nbr.size(),
        ones.end(),
        0
    );


    thrust::stable_sort_by_key
    (
        nbrSort.begin(),
        nbrSort.end(),
        ones.begin()
    );

    thrust::reduce_by_key
    (
        nbrSort.begin(),
        nbrSort.end(),
        ones.begin(),
        thrust::make_discard_iterator(),
        tmpSum.begin()
    );

    thrust::exclusive_scan
    (
        tmpSum.begin(),
        tmpSum.begin()+size(),
        lsrtStart.begin()
    );
}

//*********************************************************//

void Foam::gpulduAddressing::calcPatchSort() const
{
    if (gPatchSortAddr_.size() == nPatches())
    {
    	FatalErrorInFunction        
			<< "patch sort already calculated"
            << abort(FatalError);
    }

    gPatchSortAddr_.setSize(nPatches());
    gPatchSortCells_.setSize(nPatches());

    for(label i = 0; i < nPatches(); i++)
    {
        if( ! patchAvailable(i))
            continue;

        const labelgpuList& nbr = patchAddr(i);
        labelgpuList* sortPtr_ = new labelgpuList(nbr.size(), -1);
        gPatchSortAddr_.set(i,sortPtr_);

        labelgpuList& lst = *sortPtr_;

        labelgpuList nbrTmp(nbr);

        thrust::counting_iterator<label> first(0);
        thrust::copy
        (
            first,
            first+nbr.size(),
            lst.begin()
        );

        thrust::stable_sort_by_key
        (
            nbrTmp.begin(),
            nbrTmp.end(),
            lst.begin()
        );

        labelgpuList* cellsSortPtr= new labelgpuList(nbr.size());
        gPatchSortCells_.set(i,cellsSortPtr);
        labelgpuList& cellsSort = *cellsSortPtr;

        thrust::copy
        (
            thrust::make_permutation_iterator
            (
                nbr.begin(),
                lst.begin()
            ),
            thrust::make_permutation_iterator
            (
                nbr.begin(),
                lst.end()
            ),
            cellsSort.begin()
        );

        cellsSort.setSize
        (
            thrust::unique(cellsSort.begin(),cellsSort.end()) - cellsSort.begin()
        );
    }
}


void Foam::gpulduAddressing::calcPatchSortStart() const
{
    if (gPatchSortStartAddr_.size() == nPatches())
    {
		FatalErrorInFunction
            << "losort start already calculated"
            << abort(FatalError);
    }

    gPatchSortStartAddr_.setSize(nPatches());

    for(label i = 0; i < nPatches(); i++)
    {
        if( ! patchAvailable(i))
            continue;

        const labelgpuList& nbr = patchAddr(i);

        const labelgpuList& lsrt = gpuPatchSortAddr(i);

        labelgpuList* patchSortStartPtr_ = new labelgpuList(nbr.size() + 1, nbr.size());

        gPatchSortStartAddr_.set(i,patchSortStartPtr_);

        labelgpuList& lsrtStart = *patchSortStartPtr_;

        labelgpuList ones(nbr.size(),1);
        labelgpuList tmpSum(nbr.size());

        labelgpuList nbrSort(nbr.size());

        thrust::copy
        (
            thrust::make_permutation_iterator
            (
                nbr.begin(),
                lsrt.begin()
            ),
            thrust::make_permutation_iterator
            (
                nbr.begin(),
                lsrt.end()
            ),
            nbrSort.begin()
        );

        thrust::reduce_by_key
        (
            nbrSort.begin(),
            nbrSort.end(),
            ones.begin(),
            thrust::make_discard_iterator(),
            tmpSum.begin()
        );

        thrust::exclusive_scan
        (
            tmpSum.begin(),
            tmpSum.end(),
            lsrtStart.begin()
        );
    }
}

// * * * * * * * * * * * * * * * * Destructor  * * * * * * * * * * * * * * * //

//- Release every demand-driven addressing list held by this object
Foam::gpulduAddressing::~gpulduAddressing()
{
    // Per-patch sort lists
    gPatchSortStartAddr_.clear();
    gPatchSortAddr_.clear();
    gPatchSortCells_.clear();

    // Face-based lists
    deleteDemandDrivenData(gOwnerSortAddr_);
    deleteDemandDrivenData(losortStartAddr_);
    deleteDemandDrivenData(ownerStartAddr_);
    deleteDemandDrivenData(losortAddr_);
}


// * * * * * * * * * * * * * * * Member Functions  * * * * * * * * * * * * * //

//- Return the losort addressing, building it on first access
const Foam::labelgpuList& Foam::gpulduAddressing::losortAddr() const
{
    if (losortAddr_)
    {
        return *losortAddr_;
    }

    calcLosort();

    return *losortAddr_;
}


//- Return the owner start addressing, building it on first access
const Foam::labelgpuList& Foam::gpulduAddressing::ownerStartAddr() const
{
    if (ownerStartAddr_)
    {
        return *ownerStartAddr_;
    }

    calcOwnerStart();

    return *ownerStartAddr_;
}


//- Return the losort start addressing, building it on first access
const Foam::labelgpuList& Foam::gpulduAddressing::losortStartAddr() const
{
    if (losortStartAddr_)
    {
        return *losortStartAddr_;
    }

    calcLosortStart();

    return *losortStartAddr_;
}


void Foam::gpulduAddressing::clearOut()
{
    deleteDemandDrivenData(losortAddr_);
    deleteDemandDrivenData(ownerStartAddr_);
    deleteDemandDrivenData(losortStartAddr_);
}


//- Return the face index connecting cells a and b.
//  The smaller of the two labels is treated as the owner and the larger
//  as the neighbour. The owner's face range, taken from
//  ownerStartAddr(), is searched linearly for the neighbour in
//  upperAddr().
//  Aborts with FatalError if no such face is found.
Foam::label Foam::gpulduAddressing::triIndex(const label a, const label b) const
{
    const label own = min(a, b);
    const label nbr = max(a, b);

    const labelgpuList& ownStart = ownerStartAddr();

    label facei = ownStart.get(own);
    const label lastFace = ownStart.get(own + 1);

    const labelgpuList& upper = upperAddr();

    while (facei < lastFace)
    {
        if (upper.get(facei) == nbr)
        {
            return facei;
        }

        ++facei;
    }

    // Reaching this point means the face is missing from the addressing
    FatalErrorInFunction
        << "neighbour " << nbr << " not found for owner " << own << ". "
        << "Problem with addressing"
        << abort(FatalError);

    return -1;
}


//- Return the (bandwidth, profile) pair of the addressing.
//  For each cell the per-cell bandwidth is the largest difference
//  between a face's upper and lower address over all faces whose upper
//  address is that cell, floored at 0 (cells with no such face get 0).
//  bandwidth is the maximum per-cell value, profile is their sum.
//  Calls losortAddr(), which builds that list if it does not exist yet.
Foam::Tuple2<Foam::label, Foam::scalar> Foam::gpulduAddressing::band() const
{
    const labelgpuList& owner = lowerAddr();
    const labelgpuList& neighbour = upperAddr();

    // Face ordering that sorts the faces by upper address; this makes
    // faces of the same neighbour cell contiguous for the segmented
    // reduction below
    const labelgpuList& lsrt = losortAddr();

    labelgpuList cellBandwidth(size(), 0);
    labelgpuList diffs(neighbour.size(), 0);

    // Per-face difference of upper and lower address
    // (presumably neighbour - owner, going by the functor name)
    thrust::transform
    (
        neighbour.begin(),
        neighbour.end(),
        owner.begin(),
        diffs.begin(),
        subtractOperatorFunctor<label,label,label>()
    );

    // Several faces share the same neighbour cell, so updating
    // cellBandwidth[neighbour[facei]] once per face would be an
    // unsynchronised read-modify-write of the same entry from different
    // faces. Reduce to one maximum per distinct neighbour cell first.
    labelgpuList nbrCells(neighbour.size());
    labelgpuList nbrMaxDiff(neighbour.size());

    const label nNbrCells =
        thrust::reduce_by_key
        (
            thrust::make_permutation_iterator
            (
                neighbour.begin(),
                lsrt.begin()
            ),
            thrust::make_permutation_iterator
            (
                neighbour.begin(),
                lsrt.end()
            ),
            thrust::make_permutation_iterator
            (
                diffs.begin(),
                lsrt.begin()
            ),
            nbrCells.begin(),
            nbrMaxDiff.begin(),
            thrust::equal_to<label>(),
            thrust::maximum<label>()
        ).first
      - nbrCells.begin();

    // nbrCells now holds each neighbour cell once, so every
    // cellBandwidth entry is touched by at most one element here.
    // Taking the max with the existing value keeps the floor at 0.
    thrust::transform
    (
        nbrMaxDiff.begin(),
        nbrMaxDiff.begin() + nNbrCells,
        thrust::make_permutation_iterator
        (
            cellBandwidth.begin(),
            nbrCells.begin()
        ),
        thrust::make_permutation_iterator
        (
            cellBandwidth.begin(),
            nbrCells.begin()
        ),
        maxBinaryFunctionFunctor<label,label,label>()
    );

    label bandwidth = max(cellBandwidth);

    // Do not use field algebra because of conversion label to scalar
    scalar profile =
        thrust::reduce
        (
            cellBandwidth.begin(),
            cellBandwidth.end()
        );

    return Tuple2<label, scalar>(bandwidth, profile);
}


//************************************************************************//

//- Return lowerAddr() gathered in losort order, building the list on
//  first access
const Foam::labelgpuList& Foam::gpulduAddressing::gpuOwnerSortAddr() const
{
    if (gOwnerSortAddr_)
    {
        return *gOwnerSortAddr_;
    }

    const labelgpuList& lower = lowerAddr();
    const labelgpuList& faceOrder = losortAddr();

    gOwnerSortAddr_ = new labelgpuList(lower.size());

    // sorted[i] = lower[faceOrder[i]]
    thrust::copy
    (
        thrust::make_permutation_iterator(lower.begin(), faceOrder.begin()),
        thrust::make_permutation_iterator(lower.begin(), faceOrder.end()),
        gOwnerSortAddr_->begin()
    );

    return *gOwnerSortAddr_;
}

//- Return the sorted, de-duplicated addressing of patch i.
//  The lists for all patches are built together on first access.
const Foam::labelgpuList& Foam::gpulduAddressing::gpuPatchSortCells(const label i) const
{
    const bool built = (gPatchSortCells_.size() == nPatches());

    if (!built)
    {
        calcPatchSort();
    }

    return gPatchSortCells_[i];
}

//- Return the ordering that sorts the addressing of patch i.
//  The lists for all patches are built together on first access.
const Foam::labelgpuList& Foam::gpulduAddressing::gpuPatchSortAddr(const label i) const
{
    const bool built = (gPatchSortAddr_.size() == nPatches());

    if (!built)
    {
        calcPatchSort();
    }

    return gPatchSortAddr_[i];
}

//- Return the start offsets into the sorted addressing of patch i.
//  The lists for all patches are built together on first access.
const Foam::labelgpuList& Foam::gpulduAddressing::gpuPatchSortStartAddr(const label i) const
{
    const bool built = (gPatchSortStartAddr_.size() == nPatches());

    if (!built)
    {
        calcPatchSortStart();
    }

    return gPatchSortStartAddr_[i];
}
// ************************************************************************* //
