schedule_model.cpp 3.17 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/gpu/schedule_model.hpp>
Paul's avatar
Paul committed
2
#include <migraphx/gpu/context.hpp>
Paul's avatar
Paul committed
3
4
5
6
7
8
9
#include <migraphx/program.hpp>
#include <migraphx/operation.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

Paul's avatar
Paul committed
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy);

// Creates a HIP event with timing disabled (hipEventDisableTiming) and hands
// ownership to a hip_event_ptr, which releases it with hipEventDestroy.
// Throws (via MIGRAPHX_THROW) if hipEventCreateWithFlags does not succeed.
hip_event_ptr create_event()
{
    hipEvent_t raw_event = nullptr;
    const auto create_status = hipEventCreateWithFlags(&raw_event, hipEventDisableTiming);
    if(create_status != hipSuccess)
    {
        MIGRAPHX_THROW("Failed to create event");
    }
    return hip_event_ptr{raw_event};
}

struct wait_event
{
    std::vector<std::size_t> wait_for;
    shared<hip_event_ptr> event;
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.wait_for, "wait_for"));
    }
    std::string name() const { return "gpu::wait_event"; }
    shape compute_shape(const std::vector<shape>&) const { return {}; }

    argument compute(context& ctx, const shape&, const std::vector<argument>&) const
    {
        assert(event != nullptr);
Paul's avatar
Paul committed
36
        for(auto n : wait_for)
Paul's avatar
Paul committed
37
38
39
40
41
            ctx.get_stream(n).record(event.get());
        ctx.get_stream().wait(event.get());
        return {};
    }

Paul's avatar
Paul committed
42
    void finalize(context& ctx, const shape&, std::vector<shape>) { event = create_event(); }
Paul's avatar
Paul committed
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
};

// gpu::set_stream: switches the context's current stream to `stream`.
struct set_stream
{
    // Index of the stream to make current.
    std::size_t stream = 0;
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.stream, "stream"));
    }
    std::string name() const { return "gpu::set_stream"; }
    // Returns an empty shape: this op yields no data.
    shape compute_shape(const std::vector<shape>&) const { return {}; }

    argument compute(context& ctx, const shape&, const std::vector<argument>&) const
    {
        // `stream` is unsigned, so a `stream >= 0` assertion would be vacuous;
        // any range checking is presumably done by context::set_stream — confirm.
        ctx.set_stream(stream);
        return {};
    }
    // Also switches during finalize, presumably so that ops finalized after this
    // one are set up against the same stream they will run on.
    void finalize(context& ctx, const shape&, const std::vector<shape>&) { ctx.set_stream(stream); }
};

65
std::size_t schedule_model::concurrency() const { return streams; }
Paul's avatar
Paul committed
66
// Assigns `ins` to stream `n` by inserting a gpu::set_stream op at `ins`
// (presumably placed before it by program::insert_instruction).
void schedule_model::schedule_instruction(program& p, instruction_ref ins, std::size_t n) const
{
    set_stream switch_op{};
    switch_op.stream = n;
    p.insert_instruction(ins, switch_op);
}
Paul's avatar
Paul committed
70
71
72
73
// Makes stream `wait_on` wait for the streams listed in `wait_for` at `ins`:
// first inserts a switch to `wait_on`, then a gpu::wait_event on those streams.
void schedule_model::wait(program& p,
                          instruction_ref ins,
                          std::size_t wait_on,
                          const std::vector<std::size_t>& wait_for) const
{
    const set_stream switch_to_waiter{wait_on};
    const wait_event wait_op{wait_for};
    // Order matters: the wait must be issued after switching to the waiting stream.
    p.insert_instruction(ins, switch_to_waiter);
    p.insert_instruction(ins, wait_op);
}

// Builds the table of per-operator scheduling weights, keyed by op name.
// Ops not listed here are weighted by schedule_model::weight's fallback rule.
static std::unordered_map<std::string, std::size_t> create_weight_map()
{
    std::unordered_map<std::string, std::size_t> weights;
    // Zero-weight entries.
    weights.emplace("hip::load_literal", 0);
    weights.emplace("hip::allocate", 0);
    // Weighted entries.
    weights.emplace("gpu::convolution", 4);
    weights.emplace("gpu::conv_bias_relu", 4);
    weights.emplace("gpu::pooling", 2);
    weights.emplace("gpu::gemm", 2);
    weights.emplace("gpu::concat", 1);
    weights.emplace("hip::add_relu", 2);
    return weights;
}

// Returns the weight table, built by create_weight_map() on first use and
// reused afterwards (function-local static).
static const std::unordered_map<std::string, std::size_t>& weight_map()
{
    static const auto table = create_weight_map();
    return table;
}

Paul's avatar
Paul committed
99
std::size_t schedule_model::weight(const operation& op) const
Paul's avatar
Paul committed
100
{
Paul's avatar
Paul committed
101
    if(weight_map().count(op.name()) == 0)
102
    {
Paul's avatar
Paul committed
103
        if(is_context_free(op) or op.name()[0] == '@')
104
105
106
            return 0;
        return 1;
    }
Paul's avatar
Paul committed
107
    return weight_map().at(op.name());
Paul's avatar
Paul committed
108
109
110
111
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
112
} // namespace migraphx