// Stream scheduling pass: assigns instructions of a module to streams, reorders the
// module, asks the schedule_model to emit sched/record/wait operations, and ties
// together instructions that may run concurrently.
//
// NOTE(review): this copy of the file looks damaged - the #include directives have no
// header names, template argument lists are missing (e.g. `std::unordered_map ins2stream;`,
// `template void sort_args_by_weight(...)`), and the code is collapsed onto a few physical
// lines so that each `//` comment swallows the code that follows it. Restore the file
// from version control before building; the notes below describe the logic as written.
//
// get_inputs() / get_outputs()
//   Return a lambda that yields `i->inputs()` / `i->outputs()` for an instruction; used
//   as the `select` argument of stream_info::get_streams_from.
//
// struct stream_info
//   ins2stream        - stream assigned to an instruction (see set_stream/get_stream).
//   iweights          - an instruction's own weight: model.weight(op) when the operator is
//                       not context free and its name does not start with '@', otherwise 0;
//                       "@return" is forced to 1.
//   weights           - own weight plus the result of the recursive call for every input
//                       and every implicit dependency (memoized per instruction).
//   mod_implicit_deps - result of module::calc_implicit_deps().
//
//   accumulate_weights(last, model) - fills iweights/weights starting from `last`.
//   sort_args_by_weight(args, cmp)  - sorts by (weights, number of inputs, address).
//   sort_args(args)                 - sorts descending and returns an iterator to the first
//                                     argument after the front whose weight is
//                                     <= min_partition_threshold (2); returns args.end()
//                                     when there are fewer than two arguments.
//   assign_streams(p, n)            - walks back from the last instruction. The heaviest
//                                     input stays in the current partition, the other inputs
//                                     above the threshold each start a new partition, and
//                                     the remaining inputs stay in the current partition.
//                                     Every visited instruction is moved to p.end(). The
//                                     partition started at the last instruction gets stream
//                                     0; the others go to the least loaded of the remaining
//                                     n - 1 streams. Returns 1 when n == 1, otherwise 1 plus
//                                     the number of those streams with a nonzero load.
//   sort(p, <unnamed>)              - reorders the module by repeatedly popping the first
//                                     element of an ordered set of (priority, instruction)
//                                     and moving it to p.begin(), then adding its inputs and
//                                     implicit dependencies. The second parameter is unused.
//   set_stream(partition, n)        - assigns n to the instructions with iweights > 0.
//   different(f[, stream])          - true when `f` reports a stream that differs from
//                                     `stream` (or from the first stream `f` reports).
//   get_streams_from(start, select) - returns a callable that visits the streams of the
//                                     instructions reached through `select`, recursing
//                                     through instructions whose iweights is 0; it stops as
//                                     soon as the visitor returns false.
//   get_streams(ins)                - the instruction's own stream if it has one, otherwise
//                                     the set of streams reported for its inputs.
//   is_merge_point / is_split_point - `different` applied to the input / output streams.
//   get_recorded_instructions(start)- one instruction per input stream, keeping the one with
//                                     the smaller std::distance(x, start).
//   find_concurrent_instructions(p) - iterates the module in reverse and, for each merge
//                                     point, collects per stream the instructions (plus
//                                     their stream-less, non context-free, non-'@' inputs)
//                                     that lead to it.
//   get_conflicts(p)                - builds, with par_for and one table per thread, a map
//                                     from an instruction to the instructions on other
//                                     streams it was collected with, keyed by the
//                                     instruction with the larger module index.
//     NOTE(review): thread_conflict_tables is sized with std::thread::hardware_concurrency(),
//     which is allowed to return 0; the later `.at(tid)` would then throw - confirm how
//     par_for chooses `tid`.
//     NOTE(review): the final clean-up loop calls conflict_table[ins2] while iterating
//     conflict_table; operator[] can insert a key, and an insertion that rehashes
//     invalidates the iterators of the running loop - confirm ins2 is always present.
//
// schedule::apply(p)
//   Does nothing unless `enable` is set. Computes weights, assigns streams and sorts the
//   module; prints an annotated module when MIGRAPHX_TRACE_COMPILE or
//   MIGRAPHX_TRACE_SCHEDULE is enabled. When at least two streams are used it calls
//   model.sched for every instruction that has a stream, inserts model.record/model.wait
//   at merge points (skipping events already waited on by that stream), and finally
//   inserts an "identity" instruction after each instruction with a non-empty conflict
//   set, taking the instruction and its conflicts as arguments.
#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_SCHEDULE) auto get_inputs() { return [](auto i) { return i->inputs(); }; } auto get_outputs() { return [](auto i) { return i->outputs(); }; } struct stream_info { std::unordered_map ins2stream; std::unordered_map weights; std::unordered_map iweights; ins_dep_map mod_implicit_deps; void calc_implicit_deps(const module& p) { mod_implicit_deps = p.calc_implicit_deps(); } void accumulate_weights(instruction_ref last, const schedule_model& model) { fix([&](auto self, auto ins) -> std::size_t { if(not contains(weights, ins)) { std::size_t weight = 0; auto&& op = ins->get_operator(); if(not is_context_free(op) and op.name()[0] != '@') weight = model.weight(op); // This will ensure a stream will be assigned to return if(op.name() == "@return") weight = 1; iweights[ins] = weight; auto inputs = ins->inputs(); if(contains(mod_implicit_deps, ins)) { const auto& impl_deps = mod_implicit_deps.at(ins); inputs.insert(inputs.end(), impl_deps.begin(), impl_deps.end()); } weights[ins] = std::accumulate( inputs.begin(), inputs.end(), weight, [&](std::size_t w, instruction_ref i) { return w + self(i); }); } return weights[ins]; })(last); } template void sort_args_by_weight(std::vector& args, Compare compare) const { if(args.size() < 2) return; std::sort(args.begin(), args.end(), by(compare, [this](auto x) { return std::make_tuple( this->weights.at(x), x->inputs().size(), std::addressof(*x)); })); } std::vector::iterator sort_args(std::vector& args) { if(args.size() < 2) { return args.end(); } const std::size_t min_partition_threshold = 2; sort_args_by_weight(args, std::greater<>{}); auto it = std::lower_bound(std::next(args.begin()), args.end(), min_partition_threshold, [&](auto i, std::size_t w) { return 
this->weights[i] > w; }); assert(it == args.end() or this->weights[*it] <= min_partition_threshold); assert(it == args.end() or std::prev(it) == args.begin() or this->weights[*std::prev(it)] > min_partition_threshold); return it; } struct partition { std::size_t weight = 0; std::vector instructions{}; void add(instruction_ref ins, std::size_t w) { weight += w; instructions.push_back(ins); } }; std::size_t assign_streams(module& p, std::size_t n) { assert(n > 0); partition critical; std::unordered_map> partitions; partitions.reserve(weights.size()); fix([&](auto self, auto ins, auto& part) { assert(ins != p.end()); if(contains(partitions, ins)) return; if(not p.has_instruction(ins)) return; // Add an entry so we know the instruction was visited partitions[ins]; part.add(ins, this->iweights[ins]); auto args = ins->inputs(); auto threshold_it = this->sort_args(args); if(not args.empty()) { assert(threshold_it != args.begin()); self(args.front(), part); for(auto i : range(std::next(args.begin()), threshold_it)) { partitions[ins].emplace_back(); self(i, partitions[ins].back()); } for(auto i : range(threshold_it, args.end())) { self(i, part); } } // Sort instructions p.move_instruction(ins, p.end()); })(std::prev(p.end()), critical); // Set the critical partition to stream 0 set_stream(critical, 0); if(n == 1) { // Assign streams for the other partitions for(auto&& ins_part : partitions) for(auto&& part : ins_part.second) set_stream(part, 0); return 1; } else { std::vector streams(n - 1); // Assign streams for the other partitions for(auto&& ins_part : partitions) { std::sort(ins_part.second.begin(), ins_part.second.end(), by(std::greater<>{}, [](auto&& x) { return std::make_tuple(x.weight, x.instructions.size()); })); for(auto&& part : ins_part.second) { auto stream = std::min_element(streams.begin(), streams.end()) - streams.begin(); set_stream(part, stream + 1); streams[stream] += part.weight; } } return 1 + std::count_if(streams.begin(), streams.end(), [](auto x) { 
return x > 0; }); } } using weight_ins = std::pair; struct compare_weight_ins { bool operator()(const weight_ins& x, const weight_ins& y) const { return std::make_pair(x.first, std::addressof(*x.second)) < std::make_pair(y.first, std::addressof(*y.second)); } }; void sort(module& p, std::size_t) { std::set children; std::unordered_map visited; auto last = std::prev(p.end()); auto mw = this->weights.at(last); auto nw = mw / (p.size() + 1); auto add_child = [&](auto ins) { auto x = 1 + (mw - this->weights.at(ins)) / (nw + 1); auto w = x * this->iweights.at(ins); auto& v = visited[ins]; auto it = children.find(std::make_pair(v * w, ins)); if(it == children.end()) { v++; children.insert(std::make_pair(v * w, ins)); } }; add_child(last); while(not children.empty()) { // Pop the first element auto top = children.begin()->second; children.erase(children.begin()); p.move_instruction(top, p.begin()); for(auto ins : top->inputs()) { if(not p.has_instruction(ins)) continue; add_child(ins); } if(contains(mod_implicit_deps, top)) { for(auto ins : mod_implicit_deps.at(top)) { assert(p.has_instruction(ins)); add_child(ins); } } } } void set_stream(const partition& p, std::size_t n) { for(auto ins : p.instructions) if(iweights[ins] > 0) set_stream(ins, n); } void set_stream(instruction_ref ins, std::size_t n) { assert(iweights[ins] > 0); ins2stream[ins] = n; } std::size_t get_stream(instruction_ref ins) const { return ins2stream.at(ins); } bool has_stream(instruction_ref ins) const { return contains(ins2stream, ins); } template bool different(F f, std::size_t stream) const { bool result = false; f([&](auto s) { if(s != stream) { result = true; return false; } // cppcheck-suppress uselessAssignmentArg stream = s; return true; }); return result; } template bool different(F f) const { bool result = false; f([&](auto s) { result = this->different(f, s); return false; }); return result; } template auto get_streams_from(instruction_ref start, Selector select) const { return [=](auto f) 
{ return fix([&](auto self, auto ins) { return all_of(select(ins), [&](auto i) { if(iweights.at(i) == 0) return self(i); else return f(this->get_stream(i)); }); })(start); }; } std::unordered_set get_streams(instruction_ref ins) const { if(has_stream(ins)) return {get_stream(ins)}; std::unordered_set result; get_streams_from(ins, get_inputs())([&](auto s) { result.insert(s); return true; }); return result; } template bool is_merge_point(instruction_ref ins, Ts... xs) const { return different(get_streams_from(ins, get_inputs()), xs...); } template bool is_split_point(instruction_ref ins, Ts... xs) const { return different(get_streams_from(ins, get_outputs()), xs...); } std::vector get_recorded_instructions(instruction_ref start) { std::vector result; std::unordered_map m; fix([&](auto self, auto ins) { for(auto i : ins->inputs()) { if(iweights.at(i) == 0) { self(i); continue; } auto stream = this->get_stream(i); if(not contains(m, stream)) m[stream] = i; else m[stream] = std::min(m[stream], i, by(std::less<>{}, [&](auto x) { return std::distance(x, start); })); } })(start); std::transform( m.begin(), m.end(), std::back_inserter(result), [](auto&& p) { return p.second; }); return result; } std::unordered_map>> find_concurrent_instructions(module& p) const { std::unordered_map>> result; std::unordered_map> merge_from; dominator_info di = compute_dominator(p); result.reserve(p.size()); merge_from.reserve(p.size()); for(auto ins : reverse_iterator_for(p)) { for(auto&& arg : ins->outputs()) { if(not p.has_instruction(arg)) continue; if(is_merge_point(arg)) merge_from[ins].insert(arg); merge_from[ins].insert(merge_from[arg].begin(), merge_from[arg].end()); } if(is_split_point(ins)) { erase_if(merge_from[ins], [&](auto merge) { return di.strictly_dominate(ins, merge); }); } auto streams = this->get_streams(ins); // Collect concur instructions for each merge point. 
for(const auto& merge : merge_from[ins]) { for(auto stream : streams) { if(result[merge].size() <= stream) result[merge].resize(stream + 1); auto&& r = result[merge][stream]; r.push_back(ins); // Copy inputs if they dont have a stream(and are not a builtin and context // free). Inputs without a stream can have a implicit dependency std::copy_if(ins->inputs().begin(), ins->inputs().end(), std::back_inserter(r), [&](auto x) { return not this->has_stream(x) and not is_context_free(x->get_operator()) and x->name().front() != '@'; }); } } } return result; } std::unordered_map> get_conflicts(module& p) { using conflict_table_type = std::unordered_map>; conflict_table_type conflict_table; auto concur_ins = this->find_concurrent_instructions(p); // Compute an index for each instruction std::unordered_map ins2index; std::size_t index_total = 0; for(auto ins : iterator_for(p)) ins2index[ins] = index_total++; std::vector thread_conflict_tables( std::thread::hardware_concurrency()); std::vector index_to_ins; index_to_ins.reserve(concur_ins.size()); std::transform(concur_ins.begin(), concur_ins.end(), std::back_inserter(index_to_ins), [](auto&& it) { return it.first; }); par_for(concur_ins.size(), [&](auto ins_index, auto tid) { auto merge_first = index_to_ins[ins_index]; assert(concur_ins.count(merge_first) > 0); auto& merge_second = concur_ins.at(merge_first); // ensure there are enough elements for different threads assert(tid < thread_conflict_tables.size()); auto& thrd_table = thread_conflict_tables.at(tid); std::unordered_set checked_ins_set; auto range_i = range(merge_second.begin(), std::prev(merge_second.end())); for(auto it_i : iterator_for(range_i)) { std::unordered_set ins1_set; std::copy_if(it_i->begin(), it_i->end(), std::inserter(ins1_set, ins1_set.end()), [&](auto i) { return not contains(checked_ins_set, i); }); checked_ins_set.insert(ins1_set.begin(), ins1_set.end()); auto range_j = range(std::next(it_i), merge_second.end()); std::unordered_set ins2_set; 
for(auto it_j : iterator_for(range_j)) { std::copy_if(it_j->begin(), it_j->end(), std::inserter(ins2_set, ins2_set.end()), [&](auto i) { return not contains(checked_ins_set, i); }); } for(auto ins1 : ins1_set) { auto p1 = ins2index.at(ins1); for(auto ins2 : ins2_set) { if(ins1 == ins2) continue; auto p2 = ins2index.at(ins2); if(p2 > p1) thrd_table[ins2].insert(ins1); else thrd_table[ins1].insert(ins2); } } } }); // merge thread_conflict_tables together for(auto& tbl : thread_conflict_tables) { for(auto& it : tbl) { conflict_table[it.first].insert(it.second.begin(), it.second.end()); } } // Remove instructions from the conflict table of an earlier instruction for(auto&& ip : conflict_table) { auto ins1 = ip.first; for(auto ins2 : ip.second) if(contains(conflict_table[ins2], ins1)) conflict_table[ins2].erase(ins1); } return conflict_table; } }; void schedule::apply(module& p) const { if(not enable) return; stream_info si; si.calc_implicit_deps(p); auto last = std::prev(p.end()); si.accumulate_weights(last, model); auto nstreams = si.assign_streams(p, model.concurrency()); si.sort(p, model.concurrency()); if(enabled(MIGRAPHX_TRACE_COMPILE{}) or enabled(MIGRAPHX_TRACE_SCHEDULE{})) { p.annotate(std::cout, [&](auto ins) { std::cout << ":"; std::cout << " weight=" << si.weights.at(ins); std::cout << " input={"; si.get_streams_from(ins, get_inputs())([&](auto s) { std::cout << s << ","; return true; }); std::cout << "}"; if(si.has_stream(ins)) std::cout << " stream=" << si.get_stream(ins); }); std::cout << std::endl; } // No concurrency if(nstreams < 2) return; // Schedule instructions std::size_t wait_id = 0; std::unordered_map ins2wait; std::unordered_map> waited_for; std::unordered_map> ins2waited; ins2wait.reserve(p.size()); ins2waited.reserve(p.size()); for(auto ins : iterator_for(p)) { // Only schedule instructions that have a stream if(not si.has_stream(ins)) continue; assert(si.weights[ins] > 0); // Schedule instruction on the stream auto stream = 
si.get_stream(ins); assert(stream < model.concurrency()); model.sched(p, ins, stream); // Insert wait instructions if(si.is_merge_point(ins, stream)) { for(auto i : si.get_recorded_instructions(ins)) { if(not si.has_stream(i)) continue; auto istream = si.get_stream(i); if(stream == istream) continue; // Create a new event if it hasn't been recorded if(not contains(ins2wait, i)) { ins2wait[i] = wait_id; model.record(p, i, wait_id); wait_id++; } auto w = ins2wait.at(i); // If we already waited for the event on this stream then dont // insert another wait event if(not contains(waited_for[stream], w)) model.wait(p, ins, w); // Store the event as waited waited_for[stream].insert(w); // Store all wait events that have been waited on prior to the recorded instruction waited_for[stream].insert(ins2waited[i].begin(), ins2waited[i].end()); } } // Store wait events that have already been waited on if(si.is_split_point(ins, stream)) { ins2waited[ins] = waited_for[stream]; } } // Add memory conflicts auto conflict_table = si.get_conflicts(p); for(auto&& ip : conflict_table) { if(ip.second.empty()) continue; std::vector args; args.push_back(ip.first); args.insert(args.end(), ip.second.begin(), ip.second.end()); p.insert_instruction(std::next(ip.first), make_op("identity"), args); } } } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx