/*---------------------------------------------------------------------------*\
  =========                 |
  \\      /  F ield         | OpenFOAM: The Open Source CFD Toolbox
   \\    /   O peration     |
    \\  /    A nd           | www.openfoam.com
     \\/     M anipulation  |
-------------------------------------------------------------------------------
    Copyright (C) 2011-2016 OpenFOAM Foundation
    Copyright (C) 2019-2020 OpenCFD Ltd.
-------------------------------------------------------------------------------
License
    This file is part of OpenFOAM.

    OpenFOAM is free software: you can redistribute it and/or modify it
    under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    OpenFOAM is distributed in the hope that it will be useful, but WITHOUT
    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
    for more details.

    You should have received a copy of the GNU General Public License
    along with OpenFOAM.  If not, see <http://www.gnu.org/licenses/>.

\*---------------------------------------------------------------------------*/

#include "tensorField.H"
#include "transformField.H"

#define TEMPLATE
#include "FieldFunctionsM.C"


// * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * //

namespace Foam
{

// * * * * * * * * * * * * * * * Global Functions  * * * * * * * * * * * * * //

// Field-level definitions of the unary tensor functions named in the third
// macro argument.
// NOTE(review): UNARY_FUNCTION is not defined in this file; it presumably
// comes from FieldFunctionsM.C (included above) with the argument order
// (ReturnType, ArgumentType, Function) - confirm against that file.
UNARY_FUNCTION(scalar, tensor, tr)
UNARY_FUNCTION(sphericalTensor, tensor, sph)
UNARY_FUNCTION(symmTensor, tensor, symm)
UNARY_FUNCTION(symmTensor, tensor, twoSymm)
UNARY_FUNCTION(tensor, tensor, skew)
UNARY_FUNCTION(tensor, tensor, dev)
UNARY_FUNCTION(tensor, tensor, dev2)
UNARY_FUNCTION(scalar, tensor, det)
UNARY_FUNCTION(tensor, tensor, cof)

//- Fill \c tf with the inverse of the tensors in \c tf1.
//
//  The first tensor of \c tf1 decides, for the whole field, which diagonal
//  directions are treated as missing: a direction is flagged when the squared
//  magnitude of its diagonal entry, divided by the squared magnitude of the
//  whole first tensor, is below SMALL. Flagged directions get 1 added on the
//  diagonal before inversion and 1 subtracted again from the result.
//
//  Does nothing when \c tf is empty.
//  NOTE(review): TFOR_ALL_F_OP_FUNC_F is defined outside this file; it
//  presumably performs tf[i] = inv(source[i]) - confirm.
void inv(Field<tensor>& tf, const UList<tensor>& tf1)
{
    if (tf.empty())
    {
        return;
    }

    // Only the leading tensor is inspected
    const tensor& first = tf1[0];
    const scalar scale = magSqr(first);

    const bool dropX = (magSqr(first.xx())/scale < SMALL);
    const bool dropY = (magSqr(first.yy())/scale < SMALL);
    const bool dropZ = (magSqr(first.zz())/scale < SMALL);

    if (!(dropX || dropY || dropZ))
    {
        // No flagged direction: invert the source as it is
        TFOR_ALL_F_OP_FUNC_F(tensor, tf, =, inv, tensor, tf1)
        return;
    }

    // Unit entries on the diagonal for every flagged direction
    const tensor shift
    (
        (dropX ? 1 : 0), 0, 0,
        0, (dropY ? 1 : 0), 0,
        0, 0, (dropZ ? 1 : 0)
    );

    // Work on a shifted copy, the caller's source stays untouched
    tensorField tf1Plus(tf1);
    tf1Plus += shift;

    {
        TFOR_ALL_F_OP_FUNC_F(tensor, tf, =, inv, tensor, tf1Plus)
    }

    // Take the shift out of the result again
    tf -= shift;
}

//- Return a new temporary field, sized like \c tf, filled by the two-argument
//  inv() overload.
tmp<tensorField> inv(const UList<tensor>& tf)
{
    tmp<tensorField> tInverse = tmp<tensorField>::New(tf.size());

    tensorField& inverse = tInverse.ref();
    inv(inverse, tf);

    return tInverse;
}

//- Inverse of a temporary tensor field.
//
//  The result is obtained from New(tf), filled by the two-argument inv()
//  overload, and \c tf is cleared before returning.
//  NOTE(review): New() is defined outside this file; presumably it reuses the
//  storage of \c tf when that is possible - confirm.
tmp<tensorField> inv(const tmp<tensorField>& tf)
{
    tmp<tensorField> tInverse = New(tf);

    // Read-only view of the source
    const tensorField& source = tf();
    inv(tInverse.ref(), source);

    tf.clear();
    return tInverse;
}

// Field-level definitions of eigenValues/eigenVectors taking symmTensor
// arguments.
// NOTE(review): UNARY_FUNCTION is defined outside this file; argument order is
// presumably (ReturnType, ArgumentType, Function) - confirm.
UNARY_FUNCTION(vector, symmTensor, eigenValues)
UNARY_FUNCTION(tensor, symmTensor, eigenVectors)


//- Specialisation producing a tensor field from a symmTensor field.
//
//  Allocates a result with the same number of elements as \c stf and assigns
//  each symmTensor to the corresponding tensor entry.
//  NOTE(review): TFOR_ALL_F_OP_F is defined outside this file; presumably it
//  performs res[i] = stf[i] - confirm.
template<>
tmp<Field<tensor>> transformFieldMask<tensor>
(
    const symmTensorField& stf
)
{
    tmp<tensorField> tResult = tmp<tensorField>::New(stf.size());
    tensorField& res = tResult.ref();

    TFOR_ALL_F_OP_F(tensor, res, =, symmTensor, stf)

    return tResult;
}

//- Specialisation for a temporary symmTensor field: forwards to the
//  const-reference overload, then clears \c tstf.
template<>
tmp<Field<tensor>> transformFieldMask<tensor>
(
    const tmp<symmTensorField>& tstf
)
{
    const symmTensorField& stf = tstf();

    tmp<Field<tensor>> tResult(transformFieldMask<tensor>(stf));

    // The source is no longer needed once the result has been built
    tstf.clear();

    return tResult;
}


// * * * * * * * * * * * * * * * global operators  * * * * * * * * * * * * * //

// Field-level operator definitions.
// NOTE(review): the macros below are defined outside this file (presumably in
// FieldFunctionsM.C). The arguments look like
// (ReturnType, ArgumentType(s)..., operator symbol, operation name) - confirm.

// Unary '*' between tensor and vector fields, both directions
UNARY_OPERATOR(vector, tensor, *, hdual)
UNARY_OPERATOR(tensor, vector, *, hdual)

// vector / tensor
BINARY_OPERATOR(vector, vector, tensor, /, divide)
BINARY_TYPE_OPERATOR(vector, vector, tensor, /, divide)


// * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * //

} // End namespace Foam

// * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * //

#include "undefFieldFunctionsM.H"

// ************************************************************************* //

#define TEMPLATE
#include "gpuFieldFunctionsM.C"
#include "gpuList.C"

// * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * //

namespace Foam
{

// Explicit instantiation of the GPU list and field containers for tensor
template class gpuList<tensor>;
template class gpuField<tensor>;

//- Unary functor flagging, per diagonal direction of a tensor, whether the
//  squared diagonal entry divided by \c scale is below SMALL.
//
//  The test uses the squared entry so that it depends on the magnitude only,
//  in line with the host-side inv(), which tests magSqr(xx)/scale. Testing
//  the signed entry would flag every negative diagonal value, however large.
//
//  The declared argument type is tensor, matching operator().
struct tensorRemoveComponentsFunctor : public thrust::unary_function<tensor, Vector<bool> >
{
    //- Reference squared magnitude the diagonal entries are compared against
    const scalar scale;

    tensorRemoveComponentsFunctor(scalar _scale): scale(_scale) {}

    //- Return (xx small, yy small, zz small) for the given tensor
    __host__ __device__
    Vector<bool> operator()(const tensor& st) const
    {
        // Plain products instead of a helper call: nothing beyond the
        // component accessors has to be callable from device code
        return Vector<bool>
        (
            st.xx()*st.xx()/scale < SMALL,
            st.yy()*st.yy()/scale < SMALL,
            st.zz()*st.zz()/scale < SMALL
        );
    }
};

//- Binary functor combining two boolean triples with a component-wise
//  logical AND.
struct andBooleanVectorTensorFunctor : public thrust::binary_function<Vector<bool>,Vector<bool>,Vector<bool> >
{
    __host__ __device__
    Vector<bool> operator()(const Vector<bool>& lhs,const Vector<bool>& rhs) const
    {
        const bool bothX = lhs.x() && rhs.x();
        const bool bothY = lhs.y() && rhs.y();
        const bool bothZ = lhs.z() && rhs.z();

        return Vector<bool>(bothX, bothY, bothZ);
    }
};

// * * * * * * * * * * * * * * * global functions  * * * * * * * * * * * * * //

// GPU field-level definitions of the unary tensor functions named in the third
// macro argument.
// NOTE(review): UNARY_FUNCTION is not defined in this file; here it presumably
// comes from gpuFieldFunctionsM.C (included above) with the argument order
// (ReturnType, ArgumentType, Function) - confirm against that file.
UNARY_FUNCTION(scalar, tensor, tr)
UNARY_FUNCTION(sphericalTensor, tensor, sph)
UNARY_FUNCTION(symmTensor, tensor, symm)
UNARY_FUNCTION(symmTensor, tensor, twoSymm)
UNARY_FUNCTION(tensor, tensor, skew)
UNARY_FUNCTION(tensor, tensor, dev)
UNARY_FUNCTION(tensor, tensor, dev2)
UNARY_FUNCTION(scalar, tensor, det)
UNARY_FUNCTION(tensor, tensor, cof)

//- Fill \c tf with the inverse of the tensors in \c tf1 (GPU fields).
//
//  A gpuList built from (tf1, 1) supplies both the reference scale and the
//  per-direction flags (reduced with a component-wise AND). Flagged
//  directions get 1 added on the diagonal of a copy of \c tf1 before
//  inversion and 1 subtracted again from the result.
//
//  Does nothing when \c tf is empty.
//  NOTE(review): gpuList(tf1, 1) is presumably a one-element view of the head
//  of tf1 - confirm against the gpuList constructors.
void inv(gpuField<tensor>& tf, const gpuList<tensor>& tf1)
{
    if (tf.empty())
    {
        return;
    }

    gpuList<tensor> head(tf1,1);
    const scalar scale = sum(magSqr(head));

    // Same predicate instance for both ends of the transformed range
    const tensorRemoveComponentsFunctor isSmall(scale);

    const Vector<bool> removeCmpts =
        thrust::reduce
        (
            thrust::make_transform_iterator(head.begin(), isSmall),
            thrust::make_transform_iterator(head.end(), isSmall),
            Vector<bool>(true,true,true),
            andBooleanVectorTensorFunctor()
        );

    const bool dropX = removeCmpts.x();
    const bool dropY = removeCmpts.y();
    const bool dropZ = removeCmpts.z();

    if (!(dropX || dropY || dropZ))
    {
        // No flagged direction: invert the source as it is
        thrust::transform
        (
            tf1.begin(),
            tf1.end(),
            tf.begin(),
            invUnaryFunctionFunctor<tensor,tensor>()
        );
        return;
    }

    // Unit entries on the diagonal for every flagged direction
    const tensor shift
    (
        (dropX ? 1 : 0), 0, 0,
        0, (dropY ? 1 : 0), 0,
        0, 0, (dropZ ? 1 : 0)
    );

    // Work on a shifted copy, the caller's source stays untouched
    tensorgpuField tf1Plus(tf1);
    tf1Plus += shift;

    thrust::transform
    (
        tf1Plus.begin(),
        tf1Plus.end(),
        tf.begin(),
        invUnaryFunctionFunctor<tensor,tensor>()
    );

    // Take the shift out of the result again
    tf -= shift;
}

//- Return a new temporary GPU field, sized like \c tf, filled by the
//  two-argument inv() overload.
tmp<tensorgpuField> inv(const gpuList<tensor>& tf)
{
    tmp<tensorgpuField> tInverse(new tensorgpuField(tf.size()));

    tensorgpuField& inverse = tInverse.ref();
    inv(inverse, tf);

    return tInverse;
}

//- Inverse of a temporary GPU tensor field.
//
//  The result is obtained from reusegpuTmp<tensor, tensor>::New(tf), filled by
//  the two-argument inv() overload, and \c tf is released through
//  reusegpuTmp<tensor, tensor>::clear() before returning.
//
//  The source is only read, so it is accessed with tf() rather than tf.ref().
//  ref() requests non-const access, which is not needed here and which tmp
//  refuses when it merely wraps a const reference.
//  NOTE(review): reusegpuTmp is defined outside this file; presumably New()
//  reuses the storage of \c tf when that is possible - confirm.
tmp<tensorgpuField> inv(const tmp<tensorgpuField>& tf)
{
    tmp<tensorgpuField> tRes = reusegpuTmp<tensor, tensor>::New(tf);
    inv(tRes.ref(), tf());
    reusegpuTmp<tensor, tensor>::clear(tf);
    return tRes;
}

//UNARY_FUNCTION(vector, tensor, eigenValues)
//UNARY_FUNCTION(tensor, tensor, eigenVectors)

//UNARY_FUNCTION(vector, symmTensor, eigenValues)
//UNARY_FUNCTION(tensor, symmTensor, eigenVectors)

//- Specialisation producing a GPU tensor field from a GPU symmTensor field.
//
//  Allocates a result with the same number of elements as \c stf and fills it
//  by transforming every entry with assignFunctor<symmTensor,tensor>.
template<>
tmp<gpuField<tensor> > transformFieldMask<tensor>
(
    const symmTensorgpuField& stf
)
{
    tmp<tensorgpuField> tFull(new tensorgpuField(stf.size()));
    tensorgpuField& full = tFull.ref();

    const assignFunctor<symmTensor,tensor> toTensor;

    thrust::transform(stf.begin(), stf.end(), full.begin(), toTensor);

    return tFull;
}

//- Specialisation for a temporary GPU symmTensor field: forwards to the
//  const-reference overload, then clears \c tstf.
template<>
tmp<gpuField<tensor> > transformFieldMask<tensor>
(
    const tmp<symmTensorgpuField>& tstf
)
{
    const symmTensorgpuField& stf = tstf();

    tmp<gpuField<tensor> > tFull(transformFieldMask<tensor>(stf));

    // The source is no longer needed once the result has been built
    tstf.clear();

    return tFull;
}


// * * * * * * * * * * * * * * * global operators  * * * * * * * * * * * * * //

// GPU field-level operator definitions.
// NOTE(review): the macros below are defined outside this file (presumably in
// gpuFieldFunctionsM.C). The arguments look like
// (ReturnType, ArgumentType(s)..., operator symbol, operation name) - confirm.

// Unary '*' between tensor and vector fields, both directions
UNARY_OPERATOR(vector, tensor, *, hdual)
UNARY_OPERATOR(tensor, vector, *, hdual)

// tensor with scalar
BINARY_SYM_OPERATOR(tensor, scalar, tensor, *, outer)
BINARY_SYM_FUNCTION(tensor, scalar, tensor, multiply)
BINARY_OPERATOR(tensor, tensor, scalar, /, divide)
BINARY_TYPE_OPERATOR_FS(tensor, tensor, scalar, /, divide)

// tensor with tensor
BINARY_FULL_OPERATOR(tensor, tensor, tensor, +, add)
BINARY_FULL_OPERATOR(tensor, tensor, tensor, -, subtract)

BINARY_FULL_OPERATOR(tensor, tensor, tensor, &, dot)
BINARY_SYM_OPERATOR(vector, vector, tensor, &, dot)

// vector with vector, and vector with tensor
BINARY_FULL_OPERATOR(tensor, vector, vector, *, outer)
BINARY_FULL_OPERATOR(vector, vector, tensor, /, divide)

// sphericalTensor with tensor
BINARY_SYM_OPERATOR(tensor, sphericalTensor, tensor, +, add)
BINARY_SYM_OPERATOR(tensor, sphericalTensor, tensor, -, subtract)
BINARY_SYM_OPERATOR(tensor, sphericalTensor, tensor, &, dot)
BINARY_SYM_OPERATOR(scalar, sphericalTensor, tensor, &&, dotdot)

// symmTensor with tensor
BINARY_SYM_OPERATOR(tensor, symmTensor, tensor, +, add)
BINARY_SYM_OPERATOR(tensor, symmTensor, tensor, -, subtract)
BINARY_SYM_OPERATOR(tensor, symmTensor, tensor, &, dot)
BINARY_SYM_OPERATOR(scalar, symmTensor, tensor, &&, dotdot)


// * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * //

} // End namespace Foam

// * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * //

#include "undefgpuFieldFunctionsM.H"

