// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CuDA_H_
#define DLIB_DNN_CuDA_H_

#ifdef DLIB_USE_CUDA

#include "tensor.h"

namespace dlib
{
    namespace cuda 
    {

    // -----------------------------------------------------------------------------------

        // Declaration of the cuda-namespace multiply routine (the definition is
        // not in this header).  dest is taken by mutable reference, so it may be
        // written; src is read-only.
        // NOTE(review): presumably performs an element-wise product, i.e.
        // dest = dest * src, and expects both tensors to hold the same number
        // of elements -- confirm against the definition.
        void multiply (
            tensor& dest,
            const tensor& src
        );

    // -----------------------------------------------------------------------------------

        // Scalar-coefficient overload of affine_transform: A and B are single
        // floats passed by value.  dest is a resizable_tensor taken by mutable
        // reference; src is read-only.
        // NOTE(review): presumably sets dest = A*src + B for every element and
        // sizes dest to match src -- confirm against the definition.
        void affine_transform(
            resizable_tensor& dest,
            const tensor& src,
            const float A,
            const float B
        );

    // -----------------------------------------------------------------------------------

        // Tensor-coefficient overload of affine_transform: A and B are read-only
        // tensors rather than scalars.  dest is a resizable_tensor taken by
        // mutable reference; src is read-only.
        // NOTE(review): presumably sets dest = A*src + B element-wise, with A and
        // B supplying a separate coefficient per element; the required
        // dimensions of A and B relative to src are not visible here -- confirm
        // against the definition.
        void affine_transform(
            resizable_tensor& dest,
            const tensor& src,
            const tensor& A,
            const tensor& B
        );

    // -----------------------------------------------------------------------------------

        // Declaration of the batch normalization routine.  dest, means and
        // invstds are resizable_tensors taken by mutable reference; src, gamma
        // and beta are read-only.
        // NOTE(review): presumably dest receives the normalized, gamma-scaled and
        // beta-shifted version of src, while means and invstds receive the
        // statistics (mean and inverse standard deviation) computed from src so
        // they can be fed to batch_normalize_gradient later -- confirm against
        // the definition.
        void batch_normalize (
            resizable_tensor& dest,
            resizable_tensor& means,
            resizable_tensor& invstds,
            const tensor& src,
            const tensor& gamma, 
            const tensor& beta 
        );

        // Function object declaring the gradient counterpart of batch_normalize.
        // It is a class rather than a free function because it carries two
        // private resizable_tensor members.
        // NOTE(review): those members are presumably scratch buffers reused
        // between calls to avoid reallocation -- confirm against the definition.
        class batch_normalize_gradient
        {
        public:
            // Inputs (read-only): gradient_input, means, invstds, src, gamma.
            // Outputs (mutable references): src_grad, gamma_grad, beta_grad.
            // The call operator is non-const, so invoking it may modify the
            // private members below.
            // NOTE(review): means/invstds are presumably the values produced by
            // batch_normalize for the same src, and whether the *_grad outputs
            // are assigned or accumulated into is not visible here -- confirm.
            void operator() (
                const tensor& gradient_input,
                const tensor& means,
                const tensor& invstds,
                const tensor& src,
                const tensor& gamma,
                tensor& src_grad,
                tensor& gamma_grad,
                tensor& beta_grad
            );

        private:
            // Declared one per line; order matches the original (dvars first).
            resizable_tensor dvars;
            resizable_tensor dmeans;
        };

        // Declaration of batch_normalize_conv; its parameter list is identical to
        // that of batch_normalize.  dest, means and invstds are resizable_tensors
        // taken by mutable reference; src, gamma and beta are read-only.
        // NOTE(review): presumably the convolutional variant of batch_normalize,
        // computing one mean / inverse standard deviation per channel instead of
        // per element -- confirm against the definition.
        void batch_normalize_conv (
            resizable_tensor& dest,
            resizable_tensor& means,
            resizable_tensor& invstds,
            const tensor& src,
            const tensor& gamma, 
            const tensor& beta 
        );

        // Function object declaring the gradient counterpart of
        // batch_normalize_conv.  Its interface mirrors batch_normalize_gradient:
        // same call-operator parameter list and the same two private members.
        // NOTE(review): the members are presumably scratch buffers reused
        // between calls -- confirm against the definition.
        class batch_normalize_conv_gradient
        {
        public:
            // Inputs (read-only): gradient_input, means, invstds, src, gamma.
            // Outputs (mutable references): src_grad, gamma_grad, beta_grad.
            // The call operator is non-const, so invoking it may modify the
            // private members below.
            // NOTE(review): means/invstds are presumably the values produced by
            // batch_normalize_conv for the same src, and whether the *_grad
            // outputs are assigned or accumulated into is not visible here --
            // confirm.
            void operator() (
                const tensor& gradient_input,
                const tensor& means,
                const tensor& invstds,
                const tensor& src,
                const tensor& gamma,
                tensor& src_grad,
                tensor& gamma_grad,
                tensor& beta_grad
            );

        private:
            // Declared one per line; order matches the original (dvars first).
            resizable_tensor dvars;
            resizable_tensor dmeans;
        };

    // -----------------------------------------------------------------------------------

        // Declaration of the threshold routine.  data is taken by mutable
        // reference, so it may be modified in place; thresh is a float passed by
        // value.
        // NOTE(review): presumably binarizes data in place (each element becomes
        // 1 if it exceeds thresh, otherwise 0) -- confirm against the definition.
        void threshold (
            tensor& data,
            float thresh
        );

    // ------------------------------------------------------------------------------------

    } 
}

#endif // DLIB_USE_CUDA

#endif // DLIB_DNN_CuDA_H_

