pre_scheduling_impl.hpp 5.26 KB
Newer Older
mei-ye's avatar
mei-ye committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#ifndef MIGRAPHX_GUARD_RTGLIB_PRE_SCHEDULING_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_PRE_SCHEDULING_IMPL_HPP
#include <migraphx/common_header.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/insert_instruction.hpp>
#include <limits>

namespace migraphx {

// A node of the instruction DAG used by the pre-scheduling pass.
//
// Every field has an in-class default so a default-constructed node is fully
// initialised. Integer ids/indices/cycles default to -1; dag_partition::add_weight
// treats a negative partition as "no partition", and -1 presumably means
// "not yet assigned" for the other fields as well (TODO confirm).
struct dag_node
{
    int weight = 0;
    // can_use_stream() is true only while this is 0.
    int run_on_cpu = 0;
    int weight_sum = 0;
    int ins_ndx = -1;
    dag_node* first_child = nullptr;
    int stream = -1;
    int partition = -1;
    int sched_cycle = -1;
    int earliest_cycle = -1;
    // Instruction this node represents; must refer to a valid instruction
    // before is_literal() is called (it is dereferenced there).
    instruction_ref ins;

    // True when the node's instruction is named "@literal".
    bool is_literal() const { return (ins->name() == "@literal"); }
    // True when run_on_cpu is 0.
    bool can_use_stream() const { return (run_on_cpu == 0); }

#ifdef MIGRAPHX_DEBUG_OPT
    void dump();
#endif
};

// Per-partition bookkeeping: partition ids are handed out sequentially from 0,
// and weight_sum[id] accumulates the weights of the nodes assigned to that id.
// num_of_partition always equals weight_sum.size().
struct dag_partition
{
    dag_partition() : num_of_partition(0), weight_sum() {}

    // Appends a new partition whose weight total starts at 0 and returns its id.
    int create_partition()
    {
        const int new_id = num_of_partition;
        weight_sum.push_back(0);
        ++num_of_partition;
        return new_id;
    }

    // Adds node->weight to the total of node->partition. Nodes whose partition
    // is negative are left out of every total.
    void add_weight(dag_node* node)
    {
        if(node->partition < 0)
            return;
        assert(node->partition < num_of_partition);
        weight_sum[node->partition] += node->weight;
    }

    int num_of_partition;
    std::vector<int> weight_sum;
};

// Scheduling state for a fixed number of streams: one cycle counter per stream
// in next_cycles (all starting at 0) and a max_cycle counter starting at 0.
struct stream_info
{
    // n is the number of streams; a non-positive n yields no counters.
    // NOTE(review): single-argument and not explicit; left as-is so any
    // implicit int -> stream_info conversions at call sites keep compiling.
    stream_info(int n)
        : next_cycles(static_cast<std::size_t>(n > 0 ? n : 0), 0), num_of_streams(n), max_cycle(0)
    {
    }
    std::vector<int> next_cycles;
    int num_of_streams;
    int max_cycle;
};

82
83
84
85
86
87
// Per-instruction flag ids. The values are bit positions, not pre-shifted
// masks: they are presumably passed as the bit index `m` to
// pre_scheduling_impl::has_mask/add_mask, which test/set (1u << m).
enum instruction_mask : unsigned int
{
    record_event = 0u, // bit 0
    wait_event   = 1u  // bit 1
};

mei-ye's avatar
mei-ye committed
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
struct pre_scheduling_impl
{
    pre_scheduling_impl(program* p,
                        std::function<std::pair<int, int>(const operation&)> w,
                        int n,
                        insert_instruction ins,
                        bool v)
        : p_program(p),
          weight_func(std::move(w)),
          num_of_streams(n),
          insert_instr(std::move(ins)),
          enable_verify(v)
    {
        instr2_node.clear();
        instr2_mask.clear();
        instr2_stream.clear();
    }
    void schedule(std::list<dag_node*>&);
    void compute_weights();
107
    int get_stream(stream_info&, dag_node*) const;
mei-ye's avatar
mei-ye committed
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
    void record(stream_info&, dag_node*);
    void reorder();
    void run();
    void splice(std::list<dag_node*>&);
    void annotate(std::list<dag_node*>&);
    static bool compare_exit_nodes(dag_node* d1, dag_node* d2)
    {
        return (d1->weight_sum > d2->weight_sum);
    }

    struct weighted_topology_ordering
    {
        bool operator()(const dag_node* d1, const dag_node* d2) const
        {
            if(d1->weight_sum < d2->weight_sum)
            {
                // smaller weigth_sum is placed on top of the queue.
                return false;
            }
            else if(d1->weight_sum > d2->weight_sum)
            {
                return true;
            }
            else
            {
                // smaller instrution index is placed on top of the queue,
                return d1->ins_ndx > d2->ins_ndx;
            }
        }
    };

    struct post_schedule_ordering
    {
        bool operator()(const dag_node* d1, const dag_node* d2) const
        {
            if(d1->sched_cycle == d2->sched_cycle)
            {

                if(d1->stream == d2->stream)
                {
                    // smaller instruction index on top of queue.
                    return d1->ins_ndx > d2->ins_ndx;
                }
                else
                {
                    // smaller stream on top of queue.
                    return (d1->stream > d2->stream);
                }
            }
            else
            {
                // smaller sched_cycle on top of queue.
                return (d1->sched_cycle > d2->sched_cycle);
            }
        }
    };

    bool has_mask(instruction_ref ins, unsigned int m)
    {
        if(instr2_mask.find(ins) != instr2_mask.end())
        {
            unsigned int mask = instr2_mask[ins];
            return ((mask & (1u << m)) != 0);
        }
        return false;
    }

    void add_mask(instruction_ref ins, unsigned int m)
    {
        unsigned int mask = (instr2_mask.find(ins) != instr2_mask.end()) ? instr2_mask[ins] : 0;
        if((mask & (1u << m)) == 0)
            instr2_mask[ins] = (mask + (1u << m));
    }
    void verify();

#ifdef MIGRAPHX_DEBUG_OPT
    void dump(const std::string&);
    void dump_program();
    void dump(std::list<dag_node*>&);
#endif
    static const int min_partition_threshold = 2;

    private:
    program* p_program;
    std::function<std::pair<int, int>(const operation&)> weight_func;
    int num_of_streams;
    insert_instruction insert_instr;
    std::vector<dag_node> nodes;
    std::vector<dag_node*> exit_nodes;
    std::unordered_map<instruction_ref, dag_node*> instr2_node;
    std::unordered_map<instruction_ref, int> instr2_stream;
    std::unordered_map<instruction_ref, unsigned int> instr2_mask;
    dag_partition partition_info;
    bool enable_verify;
};
} // namespace migraphx
#endif