/*---------------------------------------------------------------------------*\
  =========                 |
  \\      /  F ield         | OpenFOAM: The Open Source CFD Toolbox
   \\    /   O peration     |
    \\  /    A nd           | www.openfoam.com
     \\/     M anipulation  |
-------------------------------------------------------------------------------
    Copyright (C) 2011-2016 OpenFOAM Foundation
    Copyright (C) 2020 OpenCFD Ltd.
-------------------------------------------------------------------------------
License
    This file is part of OpenFOAM.

    OpenFOAM is free software: you can redistribute it and/or modify it
    under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    OpenFOAM is distributed in the hope that it will be useful, but WITHOUT
    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
    for more details.

    You should have received a copy of the GNU General Public License
    along with OpenFOAM.  If not, see <http://www.gnu.org/licenses/>.

\*---------------------------------------------------------------------------*/

#include "gpumeshWavePatchDistMethod.H"
#include "gpufvMesh.H"
#include "volgpuFields.H"
#include "patchWave.H"
#include "patchDataWave.H"
#include "wallPointData.H"
#include "emptyFvPatchgpuFields.H"
#include "addToRunTimeSelectionTable.H"

// * * * * * * * * * * * * * * Static Data Members * * * * * * * * * * * * * //

namespace Foam
{
namespace gpupatchDistMethods
{
    // Define the run-time type name and debug switch (initial level 0)
    // for meshWave.
    defineTypeNameAndDebug(meshWave, 0);

    // Register meshWave's dictionary constructor in the
    // gpupatchDistMethod run-time selection table.
    addToRunTimeSelectionTable(gpupatchDistMethod, meshWave, dictionary);
}
}

// * * * * * * * * * * * * * * * * Constructors  * * * * * * * * * * * * * * //

// Construct from dictionary.
// Reads the optional boolean entry "correctWalls" (default: true) and
// forwards it, together with the mesh and the selected patch IDs, to the
// component constructor below.
Foam::gpupatchDistMethods::meshWave::meshWave
(
    const dictionary& dict,
    const gpufvMesh& mesh,
    const labelHashSet& patchIDs
)
:
    meshWave
    (
        mesh,
        patchIDs,
        dict.getOrDefault<bool>("correctWalls", true)
    )
{}


Foam::gpupatchDistMethods::meshWave::meshWave
(
    const gpufvMesh& mesh,
    const labelHashSet& patchIDs,
    const bool correctWalls
)
:
    gpupatchDistMethod(mesh, patchIDs),
    correctWalls_(correctWalls),
    nUnset_(0)
{}


// * * * * * * * * * * * * * * * Member Functions  * * * * * * * * * * * * * //

// Compute the distance field y from the patches in patchIDs_.
//
// y is first reset to GREAT. A patchWave is then run on the host mesh.
// Its cell distances replace y's internal field, and its per-patch
// distances replace the values on every non-empty boundary patch of y.
// The wave's unset count is stored in nUnset_.
//
// Returns true when nUnset_ > 0 (presumably cells the wave did not reach).
bool Foam::gpupatchDistMethods::meshWave::correct(volScalargpuField& y)
{
    y = dimensionedScalar("yWall", dimLength, GREAT);

    // The wave itself runs on the host-side mesh
    patchWave distWave(mesh_.hostmesh(), patchIDs_, correctWalls_);

    // Hand the cell distances over to y (transfer, not copy)
    y.transfer(distWave.gdistance());

    volScalargpuField::Boundary& yBoundary = y.boundaryFieldRef();

    for (label patchi = 0; patchi < yBoundary.size(); ++patchi)
    {
        // Empty patches carry no face values; leave them untouched
        if (isA<emptyFvPatchScalargpuField>(yBoundary[patchi]))
        {
            continue;
        }

        scalargpuField& patchDist = distWave.gpatchDistance()[patchi];
        yBoundary[patchi].transfer(patchDist);
    }

    nUnset_ = distWave.nUnset();

    return (nUnset_ > 0);
}


// Compute the distance field y and the near-wall vector field n from the
// patches in patchIDs_.
//
// - y is first reset to GREAT.
// - The current boundary values of n are copied from the GPU patch fields
//   into host vectorFields. They are passed as the initial patch data of a
//   patchDataWave<wallPointData<vector>> that runs on the host mesh.
// - The wave's cell distances/data replace the internal fields of y and n.
// - Its per-patch distances/data replace the values on every non-empty
//   boundary patch of y and n.
// - The wave's unset count is stored in nUnset_.
//
// Returns true when nUnset_ > 0 (presumably cells the wave did not reach).
bool Foam::gpupatchDistMethods::meshWave::correct
(
    volScalargpuField& y,
    volVectorgpuField& n
)
{
    y = dimensionedScalar("yWall", dimLength, GREAT);

    // Host-side copies of the patch data of n.
    // PtrList owns its entries, so the copies are freed when this function
    // returns. (The previous non-owning UPtrList leaked one vectorField per
    // patch on every call.) It is declared before `wave` so that it
    // outlives it, in case patchDataWave keeps a reference to it.
    PtrList<vectorField> patchData(mesh_.hostmesh().boundaryMesh().size());

    volVectorgpuField::Boundary& nbf = n.boundaryFieldRef();

    forAll(nbf, patchi)
    {
        // Device -> host copy: the wave runs on the host mesh
        patchData.set(patchi, new vectorField(nbf[patchi].size()));

        thrust::copy
        (
            nbf[patchi].begin(),
            nbf[patchi].end(),
            patchData[patchi].begin()
        );
    }

    // Do mesh wave
    patchDataWave<wallPointData<vector>> wave
    (
        mesh_.hostmesh(),
        patchIDs_,
        patchData,
        correctWalls_
    );

    // Transfer cell values from wave into y and n
    y.transfer(wave.gdistance());

    n.transfer(wave.gcellData());

    // Transfer values on patches into boundaryField of y and n
    volScalargpuField::Boundary& ybf = y.boundaryFieldRef();

    forAll(ybf, patchi)
    {
        // Empty patches carry no face values; leave them untouched
        if (!isA<emptyFvPatchScalargpuField>(ybf[patchi]))
        {
            scalargpuField& waveFld = wave.gpatchDistance()[patchi];

            ybf[patchi].transfer(waveFld);

            vectorgpuField& wavePatchData = wave.gpatchData()[patchi];

            nbf[patchi].transfer(wavePatchData);
        }
    }

    // Transfer number of unset values
    nUnset_ = wave.nUnset();

    return nUnset_ > 0;
}


// ************************************************************************* //
