generate.hpp 2.41 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP

#include <migraph/argument.hpp>
Paul's avatar
Paul committed
5
#include <migraph/literal.hpp>
Paul's avatar
Paul committed
6
#include <migraph/type_traits.hpp>
7
#include <migraph/config.hpp>
Paul's avatar
Paul committed
8
9
#include <cstdint>
#include <random>

10
namespace migraph { inline namespace MIGRAPH_INLINE_NS {
Paul's avatar
Paul committed
11

Paul's avatar
Paul committed
12
// Floating-point overload: maps a raw generator word onto one of 32 evenly
// spaced values in [-1, 1), i.e. (z % 32) / 16 - 1 in steps of 1/16.
// A word of exactly zero is special-cased to 0 (other multiples of 32 map to -1).
template <class T, MIGRAPH_REQUIRES(is_floating_point<T>{})>
constexpr T normalize(unsigned long z)
{
    constexpr unsigned long buckets  = 32;
    constexpr double half_buckets    = buckets / 2; // NOLINT
    if(z != 0)
    {
        const double unit = static_cast<double>(z % buckets) / half_buckets;
        return T(unit - 1.0);
    }
    return T(0);
}

Paul's avatar
Paul committed
24
// Signed integral overload: maps a raw generator word onto
// [half_max - max + 1, half_max], where max = numeric_limits<T>::max() and
// half_max = max / 2, i.e. a range roughly centred on zero.
//
// The arithmetic is done in T rather than in unsigned long, so the result no
// longer relies on unsigned wraparound followed by an implementation-defined
// (pre-C++20) conversion of an out-of-range value back to a signed type.
template <class T, MIGRAPH_REQUIRES(is_signed<T>{} and not is_floating_point<T>{})>
constexpr T normalize(unsigned long z)
{
    const T max      = std::numeric_limits<T>::max();
    const T half_max = static_cast<T>(max / 2);
    // z % max lies in [0, max - 1], so it always fits in T.
    const T offset = static_cast<T>(z % max);
    // Both operands lie in [0, max], so the difference lies in
    // [half_max - max + 1, half_max] and cannot overflow T.
    return static_cast<T>(half_max - offset);
}

Paul's avatar
Paul committed
32
// Unsigned integral overload: reduces a raw generator word into
// [0, numeric_limits<T>::max() - 1]. The type's maximum value itself is never
// produced, because the modulus is max rather than max + 1.
template <class T, MIGRAPH_REQUIRES(not is_signed<T>{} and std::is_integral<T>{})>
constexpr T normalize(unsigned long z)
{
    return static_cast<T>(z % std::numeric_limits<T>::max());
}

Paul's avatar
Paul committed
39
template <class T>
40
41
struct xorshf96_generator
{
Paul's avatar
Paul committed
42
43
    unsigned long x = 123456789;
    unsigned long y = 362436069;
Paul's avatar
Paul committed
44
45
    unsigned long z;

Paul's avatar
Paul committed
46
    xorshf96_generator(unsigned long seed = 0) : z(521288629ULL ^ seed) {}
47

48
    constexpr T operator()() noexcept
49
    {
50
51
52
        x ^= x << 16U;
        x ^= x >> 5U;
        x ^= x << 1U;
53

Paul's avatar
Paul committed
54
        unsigned long t = x;
Paul's avatar
Paul committed
55
56
57
        x               = y;
        y               = z;
        z               = t ^ x ^ y;
58

Paul's avatar
Paul committed
59
        return normalize<T>(z);
60
61
62
    }
};

Paul's avatar
Latest  
Paul committed
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
// Callable pseudo-random generator using the xorshift64* scheme: three
// xor-shifts on a single 64-bit state word, then a multiplicative scramble.
// Each call's result is mapped into T's range by normalize<T>.
template <class T>
struct xorshift_generator
{
    // xorshift64* is defined on exactly 64 bits of state. Plain `unsigned long`
    // is only 32 bits on LLP64 targets, so a fixed-width type is used instead.
    std::uint64_t x;

    xorshift_generator(unsigned long seed = 0) : x(initial_state(seed)) {}

    constexpr T operator()() noexcept
    {
        x ^= x >> 12U;
        x ^= x << 25U;
        x ^= x >> 27U;
        // normalize takes unsigned long; on targets where that is 32 bits only
        // the low half of the scrambled word is used.
        return normalize<T>(x * 0x2545F4914F6CDD1DULL);
    }

private:
    // A zero state is a fixed point of the xor-shift steps, and that would
    // make the generator emit the same value forever. This happens when
    // seed == 521288629, so substitute an arbitrary non-zero constant.
    static constexpr std::uint64_t initial_state(unsigned long seed) noexcept
    {
        const std::uint64_t state = 521288629ULL ^ seed;
        return state != 0 ? state : 0x9E3779B97F4A7C15ULL;
    }
};

Paul's avatar
Paul committed
79
template <class T>
Paul's avatar
Paul committed
80
std::vector<T> generate_tensor_data(const migraph::shape& s, unsigned long seed = 0)
Paul's avatar
Paul committed
81
82
{
    std::vector<T> result(s.elements());
Paul's avatar
Paul committed
83
    std::generate(result.begin(), result.end(), xorshf96_generator<T>{seed});
Paul's avatar
Latest  
Paul committed
84
    // std::generate(result.begin(), result.end(), [&]{ return seed % 7; });
Paul's avatar
Paul committed
85
    // std::generate(result.begin(), result.end(), []{ return 1; });
Paul's avatar
Paul committed
86
87
88
    return result;
}

Paul's avatar
Paul committed
89
// Declared here; the definition is not in this header.
// NOTE(review): presumably returns an argument of shape `s` filled with
// pseudo-random data derived from `seed` (e.g. via generate_tensor_data) —
// TODO confirm against the definition.
argument generate_argument(shape s, unsigned long seed = 0);
Paul's avatar
Paul committed
90

Paul's avatar
Paul committed
91
92
93
// Declared here; the definition is not in this header.
// NOTE(review): presumably returns a literal of shape `s` filled with
// pseudo-random data derived from `seed` — TODO confirm against the definition.
literal generate_literal(shape s, unsigned long seed = 0);

// Declared here; the definition is not in this header.
// NOTE(review): presumably returns a literal holding the element-wise absolute
// value of `l` — confirm against the definition.
literal abs(literal l);
Paul's avatar
Paul committed
94

95
} // inline namespace MIGRAPH_INLINE_NS
Paul's avatar
Paul committed
96
97
98
} // namespace migraph

#endif