// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNN_HIPRAND_H_
#define DLIB_DNN_HIPRAND_H_

#ifdef DLIB_USE_ROCM

#include "tensor.h"
#include "hip_errors.h"
#include "rocm_data_ptr.h"

namespace dlib
{
    namespace rocm 
    {

    // -----------------------------------------------------------------------------------

        class hiprand_generator
        {
            /*!
                WHAT THIS OBJECT REPRESENTS
                    This object is a non-copyable wrapper around an opaque random number
                    generator handle.  It can fill a rocm_data_ptr<unsigned int> with
                    random integers and fill a tensor with Gaussian or uniform random
                    numbers.

                    The constructor and destructor are defined out of line.  Presumably
                    they create and destroy a hipRAND generator referenced by the private
                    handle.  Confirm this against the implementation file.
            !*/
        public:
            // Not copyable: this object holds a single generator handle.  Copying it
            // would presumably give two objects ownership of the same generator.
            hiprand_generator(const hiprand_generator&) = delete;
            hiprand_generator& operator=(const hiprand_generator&) = delete;

            hiprand_generator() : hiprand_generator(0) {}
            /*!
                ensures
                    - Equivalent to hiprand_generator(0), i.e. uses a seed of 0.
            !*/

            hiprand_generator(unsigned long long seed);
            /*!
                ensures
                    - Constructs a generator initialized from the given seed.
            !*/

            ~hiprand_generator();
            /*!
                ensures
                    - Destroys this object.  Presumably this also releases the generator
                      referenced by the private handle (defined out of line; confirm).
            !*/

            void fill (
                rocm_data_ptr<unsigned int>& data
            );
            /*!
                ensures
                    - Fills data with random 32-bit unsigned integers.
            !*/

            void fill_gaussian (
                tensor& data,
                float mean = 0,
                float stddev = 1
            );
            /*!
                requires
                    - data.size()%2 == 0
                    - stddev >= 0
                ensures
                    - Fills data with random numbers drawn from a Gaussian distribution
                      with the given mean and standard deviation.
            !*/

            void fill_uniform (
                tensor& data
            );
            /*!
                ensures
                    - Fills data with uniform random numbers in the range (0.0, 1.0].
            !*/

        private:

            // Opaque generator handle.  This header includes no hipRAND headers, so the
            // handle is stored type-erased as void*.  It is default-initialized to
            // nullptr so the member never holds an indeterminate value, even before the
            // out-of-line constructor assigns it.
            void* handle = nullptr;
        };

    // -----------------------------------------------------------------------------------

    }  
}

#endif // DLIB_USE_ROCM

#endif // DLIB_DNN_HIPRAND_H_



