// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNN_HIPRAND_CPP_
#define DLIB_DNN_HIPRAND_CPP_

#ifdef DLIB_USE_ROCM

#include "hiprand_dlibapi.h"
#include <hiprand/hiprand.h>
#include "../string.h"

// Maps a hiprandStatus_t value to a short human-readable description.
// Only two status codes have a dedicated message; every other value
// (including any not listed here) yields the generic fallback text.
static const char* hiprand_get_error_string(hiprandStatus_t s)
{
    if (s == HIPRAND_STATUS_NOT_INITIALIZED)
        return "ROCM Runtime API initialization failed.";

    if (s == HIPRAND_STATUS_LENGTH_NOT_MULTIPLE)
        return "The requested length must be a multiple of two.";

    // No specific description available for this status.
    return "A call to hiprand failed";
}

// Check the return value of a call to the hipRAND library for an error condition.
// The expression `call` is evaluated exactly once.  If its result is anything
// other than HIPRAND_STATUS_SUCCESS, a dlib::hiprand_error is thrown whose message
// contains the stringified call expression, the file and line where the macro was
// expanded, the status code, and the text from hiprand_get_error_string().
// The do{...}while(false) wrapper makes the macro usable as a single statement.
#define CHECK_HIPRAND(call)                                                      \
do{                                                                              \
    const hiprandStatus_t error = call;                                         \
    if (error != HIPRAND_STATUS_SUCCESS)                                        \
    {                                                                          \
        std::ostringstream sout;                                               \
        sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
        sout << "code: " << error << ", reason: " << hiprand_get_error_string(error);\
        throw dlib::hiprand_error(sout.str());                            \
    }                                                                          \
}while(false)

namespace dlib
{
    namespace rocm 
    {

    // ----------------------------------------------------------------------------------------

        // Creates a hipRAND pseudo-random generator of type
        // HIPRAND_RNG_PSEUDO_DEFAULT and seeds it with `seed`.
        //
        // Throws dlib::hiprand_error (via CHECK_HIPRAND) if either the
        // creation or the seeding call reports a failure.  In both cases no
        // generator is left allocated.
        hiprand_generator::
        hiprand_generator(
            unsigned long long seed
        ) : handle(nullptr)
        {
            hiprandGenerator_t gen;
            CHECK_HIPRAND(hiprandCreateGenerator(&gen, HIPRAND_RNG_PSEUDO_DEFAULT));

            // If seeding fails the exception leaves this constructor, and the
            // destructor of a partially constructed object is never run.  The
            // generator must therefore be released here or it would leak.
            try
            {
                CHECK_HIPRAND(hiprandSetPseudoRandomGeneratorSeed(gen, seed));
            }
            catch (...)
            {
                hiprandDestroyGenerator(gen);
                throw;
            }

            // Only take ownership once the generator is fully set up.
            handle = gen;
        }

        // Releases the hipRAND generator referred to by `handle`, if any.
        hiprand_generator::
        ~hiprand_generator()
        {
            // A null handle means no generator is owned; nothing to release.
            if (!handle)
                return;

            // The status returned by the destroy call is not inspected.
            hiprandDestroyGenerator((hiprandGenerator_t)handle);
        }

        // Fills `data` on the device by calling hiprandGenerateNormal with the
        // given `mean` and `stddev`.  Does nothing when `data` is empty.
        // Throws dlib::hiprand_error (via CHECK_HIPRAND) if hipRAND reports a
        // failure.
        // NOTE(review): hipRAND may reject some lengths (this file has a
        // message for HIPRAND_STATUS_LENGTH_NOT_MULTIPLE); callers presumably
        // ensure a suitable size -- confirm.
        void hiprand_generator::
        fill_gaussian (
            tensor& data,
            float mean,
            float stddev
        )
        {
            // Only invoke the library when there is at least one element to fill.
            if (data.size() != 0)
            {
                CHECK_HIPRAND(hiprandGenerateNormal((hiprandGenerator_t)handle, data.device(), data.size(), mean, stddev));
            }
        }

        // Fills `data` on the device by calling hiprandGenerateUniform.
        // Does nothing when `data` is empty.  Throws dlib::hiprand_error
        // (via CHECK_HIPRAND) if hipRAND reports a failure.
        void hiprand_generator::
        fill_uniform (
            tensor& data
        )
        {
            // Only invoke the library when there is at least one element to fill.
            if (data.size() != 0)
            {
                CHECK_HIPRAND(hiprandGenerateUniform((hiprandGenerator_t)handle, data.device(), data.size()));
            }
        }

        // Fills `data` by calling hiprandGenerate for data.size() values.
        // Does nothing when `data` is empty.  Throws dlib::hiprand_error
        // (via CHECK_HIPRAND) if hipRAND reports a failure.
        // NOTE(review): `data` is passed straight to hiprandGenerate, so
        // rocm_data_ptr<unsigned int> presumably converts to a device
        // pointer -- confirm against its declaration.
        void hiprand_generator::
        fill (
            rocm_data_ptr<unsigned int>& data
        )
        {
            // Only invoke the library when there is at least one element to fill.
            if (data.size() != 0)
            {
                CHECK_HIPRAND(hiprandGenerate((hiprandGenerator_t)handle, data, data.size()));
            }
        }

    // -----------------------------------------------------------------------------------

    }  
}

#endif // DLIB_USE_ROCM

#endif // DLIB_DNN_HIPRAND_CPP_

