generate.hpp 1.52 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP

#include <migraph/argument.hpp>
Paul's avatar
Paul committed
5
#include <migraph/literal.hpp>
Paul's avatar
Paul committed
6
7
8
9
#include <random>

namespace migraph {

10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
/// Floating-point overload of normalize.
///
/// Returns zero when `z` compares equal to zero (which also avoids dividing
/// by zero); otherwise returns `2 / z - 1`. The literals are `double`, so for
/// `float` the expression is evaluated in `double` and then converted to `T`.
template<class T, MIGRAPH_REQUIRES(std::is_floating_point<T>{})>
T normalize(T z)
{
    return (z == 0) ? T(0) : static_cast<T>((2.0 / z) - 1.0);
}

template<class T, MIGRAPH_REQUIRES(std::is_signed<T>{})>
T normalize(T z)
{
    const auto max = std::numeric_limits<T>::max();
    const auto half_max = max/2;
    return half_max - (z % max);
}

/// Unsigned-integer overload of normalize.
///
/// Returns `z` reduced modulo the largest value representable in `T`, so
/// every input is returned unchanged except `max` itself, which maps to 0.
template<class T, MIGRAPH_REQUIRES(not std::is_signed<T>{} and std::is_integral<T>{})>
T normalize(T z)
{
    return z % std::numeric_limits<T>::max();
}

Paul's avatar
Paul committed
32
template <class T>
33
34
struct xorshf96_generator
{
Paul's avatar
Paul committed
35
36
37
    unsigned long x = 123456789;
    unsigned long y = 362436069;
    unsigned long z = 521288629;
38

39
    constexpr T operator()() noexcept
40
    {
41
42
43
        x ^= x << 16U;
        x ^= x >> 5U;
        x ^= x << 1U;
44

Paul's avatar
Paul committed
45
        unsigned long t = x;
Paul's avatar
Paul committed
46
47
48
        x               = y;
        y               = z;
        z               = t ^ x ^ y;
49

50
51
        return normalize(z);

52
53
54
    }
};

Paul's avatar
Paul committed
55
template <class T>
Paul's avatar
Paul committed
56
std::vector<T> generate_tensor_data(const migraph::shape& s, std::mt19937::result_type)
Paul's avatar
Paul committed
57
58
{
    std::vector<T> result(s.elements());
59
    std::generate(result.begin(), result.end(), xorshf96_generator<T>{});
Paul's avatar
Paul committed
60
61
62
    return result;
}

Paul's avatar
Paul committed
63
64
65
/// Declaration only; the definition is not part of this header.
/// Takes a shape by value and a seed (defaulting to 0) and returns an
/// `argument` -- presumably one filled with generated data for that shape;
/// confirm against the definition.
argument generate_argument(shape s, std::mt19937::result_type seed = 0);

/// Declaration only; the definition is not part of this header.
/// Takes the same parameters as generate_argument but returns a `literal`
/// -- presumably holding generated data for that shape; confirm against the
/// definition.
literal generate_literal(shape s, std::mt19937::result_type seed = 0);
Paul's avatar
Paul committed
66
67
68
69

} // namespace migraph

#endif