#include #include #include #include #include #include #include #include struct unary_op { std::string name() const { return "unary"; } migraphx::argument compute(migraphx::context&, const migraphx::shape&, std::vector args) const { if(args.empty()) return {}; return args.front(); } migraphx::shape compute_shape(std::vector inputs) const { if(inputs.empty()) return {}; return inputs.front(); } int output_alias(const std::vector&) const { return 0; } }; struct nary_op { std::string name() const { return "nary"; } migraphx::argument compute(migraphx::context&, const migraphx::shape&, std::vector args) const { if(args.empty()) return {}; return args.front(); } migraphx::shape compute_shape(std::vector inputs) const { if(inputs.empty()) return {}; return inputs.front(); } }; struct wait_event { std::vector wait_for; template static auto reflect(Self& self, F f) { return migraphx::pack(f(self.wait_for, "wait_for")); } std::string name() const { return "wait_event"; } migraphx::shape compute_shape(const std::vector&) const { return {}; } migraphx::argument compute(migraphx::context&, const migraphx::shape&, const std::vector&) const { return {}; } }; using instruction_map = std::unordered_map; struct schedule_model_test { instruction_map* ins2stream; std::size_t concurrency() const { return 4; } void schedule_instruction(migraphx::program&, migraphx::instruction_ref ins, std::size_t n) const { (*ins2stream)[ins] = n; } void wait(migraphx::program& p, migraphx::instruction_ref ins, std::size_t wait_on, const std::vector& wait_for) const { (*ins2stream)[ins] = wait_on; p.insert_instruction(ins, wait_event{wait_for}); } std::size_t weight(const migraphx::operation& op) const { if(op.name() == "binary" or op.name() == "unary") return 4; else return 1; } }; struct schedule_target { instruction_map* ins2stream; std::string name() const { return "schedule"; } std::vector get_passes(migraphx::context&) const { return {migraphx::schedule{schedule_model_test{ins2stream}}}; } migraphx::context get_context() const { return {}; } }; bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx::instruction_ref y) { for(auto ins : migraphx::iterator_for(p)) { if(ins->name() != "identity") continue; if(ins->inputs().size() != 2) continue; if(ins->inputs() == std::vector{x, y}) return true; if(ins->inputs() == std::vector{y, x}) return true; } return false; } void check_conflicts(migraphx::program& p, std::vector> conflicts) { migraphx::dfor(conflicts.size(), conflicts.size())([&](auto i, auto j) { if(i == j) return; for(auto ins1 : conflicts[i]) for(auto ins2 : conflicts[j]) CHECK(check_conflicts(p, ins1, ins2)); }); } std::vector get_wait_for(std::size_t wait_on, std::vector wait_for) { wait_for.erase(std::find(wait_for.begin(), wait_for.end(), wait_on)); std::sort(wait_for.begin(), wait_for.end()); return wait_for; } std::vector get_wait_for(migraphx::instruction_ref ins) { auto wait_ins = std::prev(ins); if(wait_ins->name() != "wait_event") return {}; auto wf = migraphx::any_cast(wait_ins->get_operator()).wait_for; std::sort(wf.begin(), wf.end()); return wf; } template std::vector chain(migraphx::program& p, std::size_t n, T x, migraphx::instruction_ref input) { std::vector result; for(std::size_t i = 0; i < n; i++) { result.push_back(p.add_instruction(x, input)); input = result.back(); } return result; } TEST_CASE(single_entry) { instruction_map stream; migraphx::program p; auto one = p.add_literal(1); auto onep1 = p.add_instruction(unary_op{}, one); auto onep2 = p.add_instruction(unary_op{}, one); auto binary = p.add_instruction(nary_op{}, onep1, onep2); p.compile(schedule_target{&stream}); EXPECT(stream.count(one) == 0); EXPECT(stream.at(onep1) != stream.at(onep2)); EXPECT(stream.at(binary) == 0); EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[onep1], stream[onep2]})); EXPECT(check_conflicts(p, onep1, onep2)); } TEST_CASE(double_entry) { instruction_map stream; migraphx::program p; auto one = p.add_literal(1); auto two = p.add_literal(2); auto onep = p.add_instruction(unary_op{}, one); auto twop = p.add_instruction(unary_op{}, two); auto binary = p.add_instruction(nary_op{}, onep, twop); p.compile(schedule_target{&stream}); EXPECT(stream.count(one) == 0); EXPECT(stream.count(two) == 0); EXPECT(stream.at(onep) != stream.at(twop)); EXPECT(stream.at(binary) == 0); EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[onep], stream[twop]})); // EXPECT(check_conflicts(p, onep, twop)); } TEST_CASE(two_weights) { instruction_map stream; migraphx::program p; auto one = p.add_literal(1); auto c1 = chain(p, 2, unary_op{}, one); auto i1 = p.add_instruction(unary_op{}, one); auto binary = p.add_instruction(nary_op{}, i1, c1.back()); p.compile(schedule_target{&stream}); EXPECT(stream.count(one) == 0); EXPECT(stream.at(i1) == 1); for(auto ins : c1) EXPECT(stream.at(ins) == 0); EXPECT(stream.at(binary) == 0); EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[c1.back()], stream[i1]})); check_conflicts(p, {c1, {i1}}); } TEST_CASE(four_weights) { instruction_map stream; migraphx::program p; auto one = p.add_literal(1); auto c1 = chain(p, 4, unary_op{}, one); auto c2 = chain(p, 3, unary_op{}, one); auto c3 = chain(p, 2, unary_op{}, one); auto i1 = p.add_instruction(unary_op{}, one); auto binary = p.add_instruction(nary_op{}, i1, c1.back()); p.compile(schedule_target{&stream}); EXPECT(stream.count(one) == 0); EXPECT(stream.at(i1) == 3); for(auto ins : c1) EXPECT(stream.at(ins) == 0); for(auto ins : c2) EXPECT(stream.at(ins) == 1); for(auto ins : c3) EXPECT(stream.at(ins) == 2); EXPECT(stream.at(binary) == 0); EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[c1.back()], stream[c2.back()], stream[c3.back()], stream[i1]})); check_conflicts(p, {c1, c2, c3, {i1}}); } int main(int argc, const char* argv[]) { test::run(argc, argv); }