generate.cpp 2.19 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/generate.hpp>
#include <cmath>
#include <type_traits>
Paul's avatar
Paul committed
2

Paul's avatar
Paul committed
3
namespace migraphx {
Paul's avatar
Paul committed
4
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
5

Paul's avatar
Paul committed
6
7
8
// Build an argument of shape `s` whose data comes from fill_tensor_data(s, value)
// (presumably every element set to `value` — see fill_tensor_data to confirm).
//
// Tuple shapes are handled recursively: each sub-shape is filled with the same
// `value`, and the resulting arguments are packed into a single tuple argument.
// For any other shape, the element type is selected via shape::visit_type.
argument fill_argument(shape s, unsigned long value)
{
    if(s.type() == shape::tuple_type)
    {
        std::vector<argument> parts;
        for(const auto& sub : s.sub_shapes())
            parts.push_back(fill_argument(sub, value));
        return argument(parts);
    }

    argument filled;
    s.visit_type([&](auto as) {
        using elem_t = typename decltype(as)::type;
        auto data    = fill_tensor_data<elem_t>(s, value);
        filled       = {s, data};
    });
    return filled;
}

Paul's avatar
Paul committed
30
// Build an argument of shape `s` filled with data from generate_tensor_data(s, seed).
//
// Tuple shapes are handled recursively: every sub-shape is generated with the
// same `seed` (it is passed through unchanged), and the results are packed into
// a single tuple argument. For any other shape, the element type is selected via
// shape::visit_type.
argument generate_argument(shape s, unsigned long seed)
{
    if(s.type() == shape::tuple_type)
    {
        std::vector<argument> parts;
        for(const auto& sub : s.sub_shapes())
            parts.push_back(generate_argument(sub, seed));
        return argument(parts);
    }

    argument generated;
    s.visit_type([&](auto as) {
        // bool_type is stored using char internally, so its data is generated
        // explicitly as bool instead of through the visited element type.
        if(s.type() == shape::bool_type)
        {
            auto data = generate_tensor_data<bool>(s, seed);
            generated = {s, data};
            return;
        }
        using elem_t = typename decltype(as)::type;
        auto data    = generate_tensor_data<elem_t>(s, seed);
        generated    = {s, data};
    });
    return generated;
}

Paul's avatar
Paul committed
65
// Build a literal of shape `s` whose bytes come from generate_tensor_data(s, seed).
// The generated buffer is handed to the literal as a raw char pointer.
//
// Tuple shapes are not handled here (unlike generate_argument); the element type
// is always selected via shape::visit_type.
literal generate_literal(shape s, unsigned long seed)
{
    literal result;
    s.visit_type([&](auto as) {
        // bool_type is stored using char internally (generate_argument documents and
        // special-cases this), so the visited element type may not be bool. Generate
        // genuine bool data, exactly as generate_argument does, so literals and
        // arguments of bool shape are produced the same way.
        if(s.type() == shape::bool_type)
        {
            auto v = generate_tensor_data<bool>(s, seed);
            result = {s, reinterpret_cast<char*>(v.get())};
        }
        else
        {
            using type = typename decltype(as)::type;
            auto v     = generate_tensor_data<type>(s, seed);
            result     = {s, reinterpret_cast<char*>(v.get())};
        }
    });
    return result;
}

Paul's avatar
Paul committed
76
77
78
namespace {

// Unsigned integer element types (including bool) are never negative.
template <class T>
T integral_abs(T x, std::true_type /* is_unsigned */)
{
    return x;
}

// Signed integer element types are negated in the integer domain so that every
// magnitude is preserved exactly; a round trip through double (what std::fabs
// does for integral arguments) drops low bits of int64 values beyond 2^53.
template <class T>
T integral_abs(T x, std::false_type /* is_unsigned */)
{
    return x < 0 ? static_cast<T>(-x) : x;
}

// Integral element types: exact integer absolute value.
template <class T>
T abs_elem(T x, std::true_type /* is_integral */)
{
    return integral_abs(x, std::is_unsigned<T>{});
}

// Non-integral element types (float, double, half, ...): unchanged std::fabs.
template <class T>
auto abs_elem(T x, std::false_type /* is_integral */)
{
    return std::fabs(x);
}

} // namespace

// Element-wise absolute value of a literal, returned as a new literal.
// TODO: Move to literal.cpp
literal abs(literal l)
{
    return transform(std::move(l),
                     [](auto x) { return abs_elem(x, std::is_integral<decltype(x)>{}); });
}

Paul's avatar
Paul committed
82
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
83
} // namespace migraphx