Commit 6b4175e8 authored by Paul's avatar Paul
Browse files

Add inception like test

parent cec7544c
......@@ -36,6 +36,11 @@ inline stream_range_container<Range> stream_range(const Range& r)
namespace detail {
inline void stream_write_value_impl(rank<2>, std::ostream& os, const std::string& x)
{
os << x;
}
template <class Range>
auto stream_write_value_impl(rank<1>, std::ostream& os, const Range& r)
-> decltype(r.begin(), r.end(), void())
......@@ -53,7 +58,7 @@ void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
template <class T>
void stream_write_value(std::ostream& os, const T& x)
{
detail::stream_write_value_impl(rank<1>{}, os, x);
detail::stream_write_value_impl(rank<2>{}, os, x);
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -54,7 +54,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
dead_code_elimination{},
write_literals{&ctx},
schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}},
// memory_coloring{"hip::allocate"},
memory_coloring{"hip::allocate"},
dead_code_elimination{},
// eliminate_workspace{},
eliminate_allocation{"hip::allocate"},
......
......@@ -30,6 +30,12 @@ struct unary_op
struct nary_op
{
std::string comment = "";
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::pack(f(self.comment, "comment"));
}
std::string name() const { return "nary"; }
migraphx::argument
compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
......@@ -119,6 +125,15 @@ struct schedule_target
std::size_t get_stream(migraphx::instruction_ref ins) { return model.ins2stream->at(ins); }
std::vector<std::size_t> get_streams(std::vector<migraphx::instruction_ref> inss)
{
std::vector<std::size_t> result;
std::transform(inss.begin(), inss.end(), std::back_inserter(result), [&](auto ins) {
return this->get_stream(ins);
});
return result;
}
bool has_stream(migraphx::instruction_ref ins) { return model.ins2stream->count(ins) > 0; }
};
......@@ -565,4 +580,125 @@ TEST_CASE(par_merge_multi_entry)
EXPECT(check_conflicts(p, binary1, binary2));
check_conflicts(p, {c1, {i1}, c2, {i2}});
}
TEST_CASE(inception1)
{
schedule_target t{};
migraphx::program p;
auto i1 = p.add_literal(0);
auto i2 = p.add_literal(1);
auto i3 = p.add_literal(1);
auto i4 = p.add_literal(2);
auto i7 = p.add_instruction(nary_op{"i7"}, i1,i4,i3,i2);
auto i8 = p.add_literal(2);
auto i9 = p.add_instruction(migraphx::op::identity{}, i8);
auto i10 = p.add_literal(1);
auto i11 = p.add_instruction(nary_op{"i11"}, i7,i9,i10);
auto i12 = p.add_literal(2);
auto i13 = p.add_instruction(migraphx::op::identity{}, i12);
auto i14 = p.add_literal(1);
auto i15 = p.add_literal(1);
auto i16 = p.add_literal(2);
auto i17 = p.add_instruction(nary_op{"i17"}, i11,i16,i15,i13,i14);
auto i18 = p.add_literal(2);
auto i19 = p.add_instruction(migraphx::op::identity{}, i18);
auto i20 = p.add_literal(1);
auto i21 = p.add_literal(1);
auto i22 = p.add_literal(2);
auto i23 = p.add_instruction(nary_op{"i23"}, i17,i22,i21,i19,i20);
auto i24 = p.add_literal(1);
auto i25 = p.add_instruction(nary_op{"i25"}, i23,i24);
auto i26 = p.add_literal(2);
auto i27 = p.add_instruction(migraphx::op::identity{}, i26);
auto i28 = p.add_literal(1);
auto i29 = p.add_literal(1);
auto i30 = p.add_literal(2);
auto i31 = p.add_instruction(nary_op{"i31"}, i25,i30,i29,i27,i28);
auto i32 = p.add_literal(2);
auto i33 = p.add_instruction(migraphx::op::identity{}, i32);
auto i34 = p.add_literal(1);
auto i35 = p.add_literal(1);
auto i36 = p.add_literal(2);
auto i37 = p.add_instruction(nary_op{"i37"}, i31,i36,i35,i33,i34);
auto i38 = p.add_literal(1);
auto i39 = p.add_instruction(nary_op{"i39"}, i37,i38);
auto i41 = p.add_literal(2);
auto i42 = p.add_instruction(migraphx::op::identity{}, i41);
auto i43 = p.add_literal(1);
auto i44 = p.add_literal(1);
auto i45 = p.add_literal(2);
auto i48 = p.add_instruction(nary_op{"i48"}, i39,i45,i44,i42,i43);
auto i49 = p.add_literal(2);
auto i50 = p.add_instruction(migraphx::op::identity{}, i49);
auto i51 = p.add_literal(1);
auto i52 = p.add_literal(1);
auto i53 = p.add_literal(2);
auto i54 = p.add_instruction(nary_op{"i54"}, i48,i53,i52,i50,i51);
auto i55 = p.add_literal(1);
auto i56 = p.add_instruction(migraphx::op::identity{}, i55);
auto i57 = p.add_literal(2);
auto i58 = p.add_instruction(migraphx::op::identity{}, i57);
auto i59 = p.add_literal(1);
auto i60 = p.add_literal(2);
auto i61 = p.add_instruction(nary_op{"i61"}, i54,i60,i59,i58,i56);
auto i62 = p.add_literal(2);
auto i63 = p.add_instruction(migraphx::op::identity{}, i62);
auto i64 = p.add_literal(1);
auto i65 = p.add_literal(1);
auto i66 = p.add_literal(2);
auto i69 = p.add_instruction(nary_op{"i69"}, i39,i66,i65,i63,i64);
auto i70 = p.add_instruction(migraphx::op::identity{}, i55);
auto i71 = p.add_literal(2);
auto i72 = p.add_instruction(migraphx::op::identity{}, i71);
auto i73 = p.add_literal(1);
auto i74 = p.add_literal(2);
auto i75 = p.add_instruction(nary_op{"i75"}, i69,i74,i73,i72,i70);
auto i77 = p.add_literal(1);
auto i80 = p.add_instruction(nary_op{"i80"}, i39,i77);
auto i81 = p.add_instruction(migraphx::op::identity{}, i55);
auto i82 = p.add_literal(2);
auto i83 = p.add_instruction(migraphx::op::identity{}, i82);
auto i84 = p.add_literal(1);
auto i85 = p.add_literal(2);
auto i86 = p.add_instruction(nary_op{"i86"}, i80,i85,i84,i83,i81);
auto i88 = p.add_instruction(migraphx::op::identity{}, i55);
auto i89 = p.add_literal(2);
auto i90 = p.add_instruction(migraphx::op::identity{}, i89);
auto i91 = p.add_literal(1);
auto i92 = p.add_literal(2);
auto i94 = p.add_instruction(nary_op{"i94"}, i39,i92,i91,i90,i88);
auto i96 = p.add_instruction(migraphx::op::identity{}, i55,i94,i75,i61,i86);
auto i97 = p.add_literal(2);
auto i98 = p.add_instruction(migraphx::op::identity{}, i97);
auto i99 = p.add_literal(3);
auto i100 = p.add_literal(1);
auto i101 = p.add_literal(2);
auto output = p.add_instruction(nary_op{"output"}, i96,i101,i100,i98,i99);
p.compile(t);
EXPECT(t.get_streams({i7, i11, i17, i23, i25, i31, i37, i39, i94}) == t.get_streams({i7, i7, i7, i7, i7, i7, i7, i7, i7}));
EXPECT(t.get_streams({i48, i54, i61, output}) == t.get_streams({output, output, output, output}));
EXPECT(t.get_streams({i80, i86}) == t.get_streams({i80, i80}));
EXPECT(t.get_streams({i69, i75}) == t.get_streams({i69, i69}));
EXPECT(t.get_stream(i7) != t.get_stream(i80));
EXPECT(t.get_stream(i69) != t.get_stream(i80));
EXPECT(t.get_stream(i69) != t.get_stream(i7));
EXPECT(t.get_stream(output) != t.get_stream(i7));
EXPECT(t.get_stream(output) != t.get_stream(i69));
EXPECT(t.get_stream(output) != t.get_stream(i80));
EXPECT(get_wait_for(i48) == get_wait_for({t.get_stream(i39)}));
EXPECT(get_wait_for(i80) == get_wait_for({t.get_stream(i39)}));
EXPECT(get_wait_for(i69) == get_wait_for({t.get_stream(i39)}));
// We dont wait twice
EXPECT(get_wait_for(i94).empty());
EXPECT(get_wait_for(output) == get_wait_for(t.get_stream(output), {t.get_stream(i94), t.get_stream(i75), t.get_stream(i61), t.get_stream(i86)}));
check_conflicts(p, {{i80, i86}, {i69, i75}, {i48, i54, i61, output}, {i94}});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment