#include #include #include #include #include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { auto get_inputs() { return [](auto i) { return i->inputs(); }; } auto get_outputs() { return [](auto i) { return i->outputs(); }; } struct stream_info { std::unordered_map ins2stream; std::unordered_map weights; std::unordered_map iweights; void accumulate_weights(instruction_ref last, const schedule_model& model) { fix([&](auto self, auto ins) -> std::size_t { if(not contains(weights, ins)) { std::size_t weight = 0; auto&& op = ins->get_operator(); if(not is_context_free(op) and op.name()[0] != '@') weight = model.weight(op); iweights[ins] = weight; weights[ins] = std::accumulate(ins->inputs().begin(), ins->inputs().end(), weight, [&](std::size_t w, instruction_ref i) { return w + self(i); }); } return weights[ins]; })(last); } std::vector::iterator sort_args(std::vector& args) { const std::size_t min_partition_threshold = 2; auto compare = by(std::less<>{}, [&](auto x) { return std::make_tuple(this->weights[x], x->inputs().size()); }); if (args.size() < 2) { return args.end(); } else if (args.size() == 2) { auto w1 = this->weights[args[0]]; auto w2 = this->weights[args[1]]; if (std::make_tuple(w1, args[0]->inputs().size()) > std::make_tuple(w2, args[1]->inputs().size())) { std::swap(args[0], args[1]); std::swap(w1, w2); } if (w1 > min_partition_threshold) return args.begin(); if (w2 > min_partition_threshold) return args.begin()+1; return args.end(); } std::sort(args.begin(), args.end(), compare); return std::upper_bound(args.begin(), args.end(), min_partition_threshold, [&](std::size_t w, auto i) { return w < this->weights[i]; }); } struct partition { std::size_t weight = 0; std::vector instructions{}; void add(instruction_ref ins, std::size_t w) { weight += w; instructions.push_back(ins); } }; void assign_streams(program& p, std::size_t n) { partition critical; std::unordered_map> partitions; partitions.reserve(weights.size()); fix([&](auto self, auto ins, auto& part) { if (contains(partitions, ins)) return; partitions[ins]; part.add(ins, this->iweights[ins]); auto args = ins->inputs(); auto threshold_it = sort_args(args); for(auto i:range(args.begin(), threshold_it)) { self(i, part); } for(auto i:range(threshold_it, args.end())) { if (i == args.back()) { self(i, part); } else { partitions[ins].emplace_back(); self(i, partitions[ins].back()); } } // Sort instructions p.move_instruction(ins, p.end()); })(std::prev(p.end()), critical); // Set the critical partition to stream 0 set_stream(critical, 0); std::vector streams(n - 1); // Assign streams for the other partitions for(auto&& ins_part : partitions) { std::sort( ins_part.second.begin(), ins_part.second.end(), by(std::greater<>{}, [](auto&& x) { return std::make_tuple(x.weight, x.instructions.size()); })); for(auto&& part : ins_part.second) { auto stream = std::min_element(streams.begin(), streams.end()) - streams.begin(); set_stream(part, stream + 1); streams[stream] += part.weight; } } } void set_stream(const partition& p, std::size_t n) { for(auto ins : p.instructions) if(iweights[ins] > 0) set_stream(ins, n); } void set_stream(instruction_ref ins, std::size_t n) { assert(iweights[ins] > 0); ins2stream[ins] = n; } std::size_t get_stream(instruction_ref ins) const { return ins2stream.at(ins); } bool has_stream(instruction_ref ins) const { return contains(ins2stream, ins); } template bool different(F f, std::size_t stream) const { bool result = false; f([&](auto s) { if(s != stream) { result = true; return false; } // cppcheck-suppress uselessAssignmentArg stream = s; return true; }); return result; } template bool different(F f) const { bool result = false; f([&](auto s) { result = different(f, s); return false; }); return result; } template auto get_streams_from(instruction_ref start, Selector select) const { return [=](auto f) { return fix([&](auto self, auto ins) { for(auto i : select(ins)) { if(iweights.at(i) == 0) { if(not self(i)) return false; } else { if(not f(get_stream(i))) return false; } } return true; })(start); }; } std::unordered_set get_streams(instruction_ref ins) const { if(has_stream(ins)) return {get_stream(ins)}; std::unordered_set result; get_streams_from(ins, get_inputs())([&](auto s) { result.insert(s); return true; }); return result; } template bool is_merge_point(instruction_ref ins, Ts... xs) const { return different(get_streams_from(ins, get_inputs()), xs...); } template bool is_split_point(instruction_ref ins, Ts... xs) const { return different(get_streams_from(ins, get_outputs()), xs...); } std::vector get_recorded_instructions(instruction_ref start) { std::vector result; std::unordered_map m; fix([&](auto self, auto ins) { for(auto i : ins->inputs()) { if(iweights.at(i) == 0) { self(i); continue; } auto stream = get_stream(i); if(not contains(m, stream)) m[stream] = i; else m[stream] = std::min(m[stream], i, by(std::less<>{}, [&](auto x) { return std::distance(x, start); })); } })(start); std::transform( m.begin(), m.end(), std::back_inserter(result), [](auto&& p) { return p.second; }); return result; } std::unordered_map>> find_concurrent_instructions(program& p) { std::unordered_map>> result; std::unordered_map> merge_from; result.reserve(p.size()); merge_from.reserve(p.size()); for(auto ins : reverse_iterator_for(p)) { for(auto&& arg : ins->outputs()) { if(is_merge_point(arg)) merge_from[ins].insert(arg); merge_from[ins].insert(merge_from[arg].begin(), merge_from[arg].end()); } auto streams = get_streams(ins); // Collect concur instructions for each merge point. for(auto& merge : merge_from[ins]) { for(auto stream : streams) { if(result[merge].size() <= stream) result[merge].resize(stream + 1); auto&& r = result[merge][stream]; r.push_back(ins); // Copy inputs if they dont have a stream(and are not a builtin and context // free). Inputs without a stream can have a implicit dependency std::copy_if(ins->inputs().begin(), ins->inputs().end(), std::back_inserter(r), [&](auto x) { return not this->has_stream(x) and not is_context_free(x->get_operator()) and x->name().front() != '@'; }); } } } return result; } }; void schedule::apply(program& p) const { stream_info si; auto last = std::prev(p.end()); si.accumulate_weights(last, model); si.assign_streams(p, model.concurrency()); if(enabled(MIGRAPHX_TRACE_COMPILE{})) { p.annotate(std::cout, [&](auto ins) { std::cout << ":"; std::cout << " weight=" << si.weights.at(ins); std::cout << " input={"; si.get_streams_from(ins, get_inputs())([&](auto s) { std::cout << s << ","; return true; }); std::cout << "}"; if(si.has_stream(ins)) std::cout << " stream=" << si.get_stream(ins); }); std::cout << std::endl; } // Schedule instructions std::size_t wait_id = 0; std::unordered_map ins2wait; std::unordered_map> waited_for; std::unordered_map> ins2waited; ins2wait.reserve(p.size()); ins2waited.reserve(p.size()); for(auto ins : iterator_for(p)) { // Only schedule instructions that have a stream if(not si.has_stream(ins)) continue; assert(si.weights[ins] > 0); // Schedule instruction on the stream auto stream = si.get_stream(ins); assert(stream < model.concurrency()); model.sched(p, ins, stream); // Insert wait instructions if(si.is_merge_point(ins, stream)) { for(auto i : si.get_recorded_instructions(ins)) { if(not si.has_stream(i)) continue; auto istream = si.get_stream(i); if(stream == istream) continue; // Create a new event if it hasn't been recorded if(not contains(ins2wait, i)) { ins2wait[i] = wait_id; model.record(p, i, wait_id); wait_id++; } auto w = ins2wait.at(i); // If we already waited for the event on this stream then dont // insert another wait event if(not contains(waited_for[stream], w)) model.wait(p, ins, w); // Store the event as waited waited_for[stream].insert(w); // Store all wait events that have been waited on prior to the recorded instruction waited_for[stream].insert(ins2waited[i].begin(), ins2waited[i].end()); } } // Store wait events that have already been waited on if(si.is_split_point(ins, stream)) { ins2waited[ins] = waited_for[stream]; } } // Add memory conflicts auto concur_ins = si.find_concurrent_instructions(p); for(auto&& merge : concur_ins) { dfor(merge.second.size(), merge.second.size())([&](auto i, auto j) { if(i == j) return; if (merge.second[i].empty()) return; if (merge.second[j].empty()) return; for(auto ins1 : merge.second[i]) { auto args = merge.second[j]; args.insert(args.begin(), ins1); p.insert_instruction(merge.first, op::identity{}, args); } }); } } } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx