// generate.hpp
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_GENERATE_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_GENERATE_HPP
#include <migraphx/argument.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/type_traits.hpp>
#include <migraphx/config.hpp>

#include <algorithm>
#include <limits>
#include <random>
#include <type_traits>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Maps a raw generator word `z` to a small floating-point value.
//
// A non-zero `z` is reduced to k = z % 32 and mapped to k / 16 - 1, so the
// result is one of 32 evenly spaced values in [-1, 0.9375].
// An input of exactly 0 is special-cased to return 0 (not -1).
template <class T, MIGRAPHX_REQUIRES(is_floating_point<T>{})>
constexpr T normalize(unsigned long z)
{
    if(z != 0)
    {
        constexpr unsigned long buckets = 32;
        const double scaled             = static_cast<double>(z % buckets) / (buckets / 2);
        return T(scaled - 1.0);
    }
    return T(0);
}

// Maps a raw generator word `z` to a small signed integer of type T.
//
// With max = numeric_limits<T>::max() / 64 and half_max = max / 2, the result
// is half_max - (z % max), which lies in [half_max - (max - 1), half_max].
// NOTE(review): for 8-bit types max is 1, so every result is 0.
template <class T, MIGRAPHX_REQUIRES(is_signed<T>{} and not is_floating_point<T>{})>
constexpr T normalize(unsigned long z)
{
    // Type of `max() / 64` after the usual integral promotions
    // (int for narrow types, T itself for int/long/long long).
    using wide_t          = decltype(std::numeric_limits<T>::max() / 64);
    const wide_t max      = std::numeric_limits<T>::max() / 64;
    const wide_t half_max = max / 2;
    // z % max is always in [0, max), so it is exactly representable in wide_t.
    // Doing the subtraction in the signed type avoids the unsigned wrap-around
    // followed by an implementation-defined (pre-C++20) narrowing to T.
    const auto offset = static_cast<wide_t>(z % max);
    return static_cast<T>(half_max - offset);
}

// Maps a raw generator word `z` to a small unsigned integer of type T:
// z % (numeric_limits<T>::max() / 64).
//
// For very narrow unsigned integral types the divisor is 0. For example,
// bool satisfies this overload's constraints and max() / 64 == 0 for it.
// In that case the modulus falls back to 2 instead of dividing by zero (UB).
template <class T, MIGRAPHX_REQUIRES(not is_signed<T>{} and std::is_integral<T>{})>
constexpr T normalize(unsigned long z)
{
    constexpr auto max     = std::numeric_limits<T>::max() / 64;
    constexpr auto modulus = (max == 0) ? 2 : max;
    return static_cast<T>(z % modulus);
}

template <class T>
41
42
struct xorshf96_generator
{
Paul's avatar
Paul committed
43
44
    unsigned long x = 123456789;
    unsigned long y = 362436069;
Paul's avatar
Paul committed
45
46
    unsigned long z;

Paul's avatar
Paul committed
47
    xorshf96_generator(unsigned long seed = 0) : z(521288629ULL ^ seed) {}
48

49
    constexpr T operator()() noexcept
50
    {
51
52
53
        x ^= x << 16U;
        x ^= x >> 5U;
        x ^= x << 1U;
54

Paul's avatar
Paul committed
55
        unsigned long t = x;
Paul's avatar
Paul committed
56
57
58
        x               = y;
        y               = z;
        z               = t ^ x ^ y;
59

Paul's avatar
Paul committed
60
        return normalize<T>(z);
61
62
63
    }
};

// Single-word shift/xor generator.
// Each call applies shifts of 12, 25 and 27 to the state, multiplies it by a
// fixed odd constant, and passes the product through normalize<T>.
template <class T>
struct xorshift_generator
{
    unsigned long x;

    xorshift_generator(unsigned long seed = 0) : x(initial_state(seed)) {}

    constexpr T operator()() noexcept
    {
        x ^= x >> 12U;
        x ^= x << 25U;
        x ^= x >> 27U;
        return normalize<T>(x * 0x2545F4914F6CDD1DULL);
    }

    private:
    // The shift/xor step maps an all-zero state to itself, so a zero start
    // state would make every call return normalize<T>(0). The seed that
    // triggers this (seed == 521288629) is redirected to the seed-0 state;
    // every other seed yields exactly the same state as before.
    static constexpr unsigned long initial_state(unsigned long seed) noexcept
    {
        const unsigned long state = 521288629ULL ^ seed;
        return state == 0 ? 521288629UL : state;
    }
};

template <class T>
81
auto generate_tensor_data(const migraphx::shape& s, unsigned long seed = 0)
Paul's avatar
Paul committed
82
{
83
84
    auto result = make_shared_array<T>(s.elements());
    std::generate(result.get(), result.get() + s.elements(), xorshf96_generator<T>{seed});
Paul's avatar
Paul committed
85
86
87
    return result;
}

template <class T>
89
auto fill_tensor_data(const migraphx::shape& s, unsigned long value = 0)
Paul's avatar
Paul committed
90
{
91
92
    auto result = make_shared_array<T>(s.elements());
    std::generate(result.get(), result.get() + s.elements(), [=] { return value; });
Paul's avatar
Paul committed
93
94
95
96
97
    return result;
}

// The functions below are only declared here; their definitions live out of line.

// NOTE(review): presumably returns an argument of shape `s` with every element
// set to `value` (cf. fill_tensor_data) — confirm against the definition.
argument fill_argument(shape s, unsigned long value = 0);

// NOTE(review): presumably returns an argument of shape `s` holding
// pseudo-random data derived from `seed` — confirm against the definition.
argument generate_argument(shape s, unsigned long seed = 0);

// NOTE(review): presumably the literal counterpart of generate_argument —
// confirm against the definition.
literal generate_literal(shape s, unsigned long seed = 0);

// NOTE(review): by its name, presumably an element-wise absolute value of `l` —
// confirm against the definition.
literal abs(literal l);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif