generate.hpp 1.14 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP

#include <migraph/argument.hpp>
Paul's avatar
Paul committed
5
#include <migraph/literal.hpp>
Paul's avatar
Paul committed
6
7
8
9
#include <random>

namespace migraph {

10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// Marsaglia's "xorshf96" xorshift generator: a cheap, deterministic source of
// pseudo-random test data.
//
// Each call advances the three-word state (x, y, z) and returns the new z
// converted to T. No scaling is applied: for floating-point T the results are
// large magnitudes (not a [0, 1) range), and for integral T the 64-bit value
// is simply converted (modulo 2^N for unsigned T).
//
// The state words are `unsigned long long` rather than `unsigned long`: the
// latter is 64 bits on LP64 platforms (Linux/macOS) but only 32 bits on LLP64
// (Windows), which made the shifts below produce a different sequence per
// platform. With `unsigned long long` (64 bits on all mainstream platforms)
// the LP64 sequence is reproduced everywhere.
template <class T>
struct xorshf96_generator
{
    unsigned long long x = 123456789;
    unsigned long long y = 362436069;
    unsigned long long z = 521288629;

    constexpr T operator()()
    {
        x ^= x << 16u;
        x ^= x >> 5u;
        x ^= x << 1u;

        // Rotate the state and fold the scrambled old x into the new z.
        const unsigned long long t = x;
        x                          = y;
        y                          = z;
        z                          = t ^ x ^ y;

        return static_cast<T>(z);
    }
};

Paul's avatar
Paul committed
33
34
35
36
37
38
// Returns a vector of s.elements() pseudo-random values of type T, produced by
// xorshf96_generator<T>.
//
// `seed` selects the sequence. It is XOR-ed into the generator's z word, so
// different seeds give different data, and seed == 0 leaves the default state
// untouched (reproducing the generator's default sequence). x and y keep their
// non-zero defaults, so the xorshift state can never become all-zero.
template <class T>
std::vector<T> generate_tensor_data(migraph::shape s, std::mt19937::result_type seed = 0)
{
    std::vector<T> result(s.elements());
    xorshf96_generator<T> gen{};
    gen.z ^= seed;
    std::generate(result.begin(), result.end(), gen);
    return result;
}

Paul's avatar
Paul committed
44
45
46
// Declared here; the definition is out of line (not visible in this header).
// NOTE(review): presumably builds an argument of shape `s` filled with
// generated data selected by `seed` — confirm against the definition.
argument generate_argument(shape s, std::mt19937::result_type seed = 0);

// Declared here; the definition is out of line (not visible in this header).
// NOTE(review): presumably the literal counterpart of generate_argument for
// shape `s` and `seed` — confirm against the definition.
literal generate_literal(shape s, std::mt19937::result_type seed = 0);
Paul's avatar
Paul committed
47
48
49
50

} // namespace migraph

#endif