generate.cpp 1.42 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/generate.hpp>
Paul's avatar
Paul committed
2

Paul's avatar
Paul committed
3
namespace migraphx {
Paul's avatar
Paul committed
4
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
5

Paul's avatar
Paul committed
6
7
8
9
10
11
/// Create an argument of shape @p s with every element set to @p value.
///
/// @param s     Shape (element type and lengths) of the result.
/// @param value Fill value; it is handed to fill_tensor_data, which presumably
///              converts it to the shape's element type (TODO confirm).
/// @return      An argument holding the filled data.
argument fill_argument(shape s, unsigned long value)
{
    argument result;
    s.visit_type([&](auto as) {
        // bool_type is stored with char as its visited type, so fill through
        // bool explicitly (matching generate_argument); filling through char
        // could store a byte other than 0/1, which is not a valid bool.
        if(s.type() == shape::bool_type)
        {
            auto v = fill_tensor_data<bool>(s, value);
            result = {s, v};
        }
        else
        {
            using type = typename decltype(as)::type;
            auto v     = fill_tensor_data<type>(s, value);
            result     = {s, v};
        }
    });
    return result;
}

Paul's avatar
Paul committed
17
/// Create an argument of shape @p s filled with generated data.
///
/// @param s    Shape (element type and lengths) of the result.
/// @param seed Seed forwarded to generate_tensor_data.
/// @return     An argument holding the generated data.
argument generate_argument(shape s, unsigned long seed)
{
    argument result;
    s.visit_type([&](auto as) {
        // we use char type to store bool type internally, so bool_type
        // needs special processing to generate data
        if(s.type() == shape::bool_type)
        {
            auto v = generate_tensor_data<bool>(s, seed);
            result = {s, v};
        }
        else
        {
            using type = typename decltype(as)::type;
            auto v     = generate_tensor_data<type>(s, seed);
            result     = {s, v};
        }
    });
    return result;
}

Paul's avatar
Paul committed
38
/// Create a literal of shape @p s filled with generated data.
///
/// The generated buffer is handed to the literal as raw bytes via a char
/// pointer; presumably the literal copies s.bytes() from it (TODO confirm),
/// so the temporary buffer `v` only needs to live until the assignment ends.
///
/// @param s    Shape (element type and lengths) of the result.
/// @param seed Seed forwarded to generate_tensor_data.
/// @return     A literal holding the generated data.
literal generate_literal(shape s, unsigned long seed)
{
    literal result;
    s.visit_type([&](auto as) {
        // bool_type is stored with char as its visited type; generating through
        // char would produce arbitrary bytes that are not valid bool values.
        // Generate through bool instead, consistent with generate_argument.
        if(s.type() == shape::bool_type)
        {
            // The bytes are copied as chars, so bool must occupy one char.
            static_assert(sizeof(bool) == sizeof(char), "bool must be one byte");
            auto v = generate_tensor_data<bool>(s, seed);
            result = {s, reinterpret_cast<char*>(v.get())};
        }
        else
        {
            using type = typename decltype(as)::type;
            auto v     = generate_tensor_data<type>(s, seed);
            result     = {s, reinterpret_cast<char*>(v.get())};
        }
    });
    return result;
}

Paul's avatar
Paul committed
49
50
51
// TODO: Move to literal.cpp
/// Element-wise absolute value of a literal.
///
/// Applies std::fabs to each element via transform. How the (floating-point)
/// result of fabs is stored back for non-float element types depends on
/// transform — NOTE(review): confirm its conversion behavior.
///
/// @param l Input literal (taken by value and moved into transform).
/// @return  A literal containing |x| for each element x of @p l.
literal abs(literal l)
{
    return transform(std::move(l), [](auto x) { return std::fabs(x); });
}

Paul's avatar
Paul committed
55
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
56
} // namespace migraphx