"...gpu/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "2d33413cd84da8c68bbc87c38ffa1a1bfb48da6d"
compile_ops.cpp 2.82 KB
Newer Older
1
2
3
4
5
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/module.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
Paul Fultz II's avatar
Paul Fultz II committed
6
#include <migraphx/par_for.hpp>
7
8
9
10
11
12
13
14
#include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/gpu/compile_pointwise.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

15
16
// Declares the MIGRAPHX_GPU_COMPILE_PARALLEL environment-variable tag. par_compile
// reads it through value_of() and uses it as the divisor of the work count when
// calling par_for.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);

17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
// Placeholder operator named "gpu::precompile_op" that wraps another operator
// until compile_ops::apply swaps it for a compiled one.
//
// The final input is treated specially: it is excluded when computing the
// output shape and is reported by output_alias() as the aliased argument
// (presumably the preallocated output buffer - confirm against the pass that
// creates these instructions).
struct precompile_op
{
    // Wrapped operator; an identity operator unless one is supplied.
    operation op = op::identity{};

    // Exposes the single field "op" to the reflection machinery.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.op, "op"));
    }

    // Name under which this operator is looked up in compile_ops::apply.
    std::string name() const { return "gpu::precompile_op"; }

    // Shape of the wrapped operator evaluated on every input except the last.
    // `inputs` must not be empty: removing the last element of an empty vector
    // is undefined behaviour.
    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
        inputs.erase(inputs.end() - 1);
        return op.compute_shape(inputs, mods);
    }

    // Index of the last input; evaluates to -1 when `shapes` is empty.
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        const auto input_count = static_cast<std::ptrdiff_t>(shapes.size());
        return input_count - 1;
    }
};

// Registers precompile_op through MIGraphX's operator-registration macro
// (presumably so "gpu::precompile_op" can be found by name - confirm in register_op.hpp).
MIGRAPHX_REGISTER_OP(precompile_op);

// Compiler entry for operators named "pointwise".
struct pointwise_compiler
{
    // Key under which make_compilers() files this compiler.
    std::string name() const { return "pointwise"; }

    // Compiles the first module attached to `ins` with compile_pointwise, using the
    // shapes of the instruction's inputs. The operator argument is ignored.
    // Precondition (checked only by assert): `ins` has at least one module input.
    operation apply(context& ctx, instruction_ref ins, const operation&) const
    {
        const auto& attached_modules = ins->module_inputs();
        assert(not attached_modules.empty());
        auto* pointwise_module = attached_modules.front();
        return compile_pointwise(ctx, to_shapes(ins->inputs()), *pointwise_module);
    }
};

// Type-erased compiler callback: given the context, the instruction being compiled
// and the operator taken from its precompile_op, returns the operation that
// compile_ops::apply installs in place of that instruction.
using compiler_function = std::function<operation(context&, instruction_ref, operation)>;

// Wraps a compiler object in a compiler_function. The object is copied into the
// closure, and every call is forwarded to its apply() member.
template <class T>
compiler_function make_compiler_function(T x)
{
    auto forward_to_apply = [x](auto&&... args) { return x.apply(args...); };
    return compiler_function{forward_to_apply};
}

template <class... Ts>
std::unordered_map<std::string, compiler_function> make_compilers(Ts... xs)
{
    return {{xs.name(), make_compiler_function(xs)}...};
}

Paul Fultz II's avatar
Paul Fultz II committed
69
70
71
72
73
74
// Outcome of compiling one gpu::precompile_op instruction.
// Member order matters: compile_ops::apply brace-initialises this as {op, ins}.
struct compiled_result
{
    // Operator returned by the compiler callback.
    operation op;
    // Instruction that will be replaced with `op`.
    instruction_ref ins;
};

75
76
77
78
79
80
81
82
// Runs f(i) for each i in [0, n) through par_for.
//
// The second argument handed to par_for is n divided by the value of the
// MIGRAPHX_GPU_COMPILE_PARALLEL environment variable (its grain-size parameter,
// per par_for.hpp - confirm). When the variable is unset the divisor defaults
// to n, giving a quotient of 1.
//
// n - number of work items; nothing happens when it is 0.
// f - callable invoked with the index of each work item.
template <class F>
void par_compile(std::size_t n, F f)
{
    if(n == 0)
        return;
    std::size_t parallelism = value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n);
    // MIGRAPHX_GPU_COMPILE_PARALLEL=0 would otherwise divide by zero below;
    // treat it the same as the variable being unset.
    if(parallelism == 0)
        parallelism = n;
    par_for(n, n / parallelism, f);
}

83
84
85
// Replaces every "gpu::precompile_op" instruction in `m` with the operator obtained
// by compiling the operator it wraps.
//
// Works in three steps:
//   1. walk the module and queue one compile job per precompile_op instruction;
//   2. run all queued jobs through par_compile;
//   3. once every job has finished, swap each instruction for its compiled
//      operator, keeping the instruction's existing inputs.
// The module is not modified until step 3.
//
// Throws std::out_of_range if a wrapped operator has no registered compiler.
void compile_ops::apply(module& m) const
{
    auto compilers = make_compilers(pointwise_compiler{});
    std::vector<std::function<compiled_result()>> compiles;

    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "gpu::precompile_op")
            continue;
        operation preop = any_cast<precompile_op>(ins->get_operator()).op;
        assert(contains(compilers, preop.name()));
        // at() instead of operator[]: with asserts disabled, operator[] would
        // silently insert an empty std::function for an unknown name and the
        // failure would only appear later as std::bad_function_call inside a
        // compile job. at() fails here, at the lookup.
        auto c = compilers.at(preop.name());
        // Everything is captured by value so the job owns what it needs when it runs.
        compiles.emplace_back([=]() -> compiled_result { return {c(*ctx, ins, preop), ins}; });
    }

    // One pre-sized slot per job so each job writes only to its own index.
    std::vector<compiled_result> results(compiles.size());
    par_compile(compiles.size(), [&](auto i) { results[i] = compiles[i](); });

    for(const auto& cr : results)
    {
        m.replace_instruction(cr.ins, cr.op, cr.ins->inputs());
    }
}

} // namespace gpu

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx