"docs/vscode:/vscode.git/clone" did not exist on "2340798353bc58398b6d45f582c7c79b670d0256"
Commit 8dda7ad2 authored by Paul's avatar Paul
Browse files

Fix bug in eliminate_concat

parent 027e1fa7
...@@ -13,8 +13,6 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -13,8 +13,6 @@ inline namespace MIGRAPHX_INLINE_NS {
void eliminate_allocation::apply(program& p) const void eliminate_allocation::apply(program& p) const
{ {
assert(alignment > 0); assert(alignment > 0);
if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
return;
std::size_t n = 0; std::size_t n = 0;
std::vector<std::pair<instruction_ref, std::size_t>> allocs; std::vector<std::pair<instruction_ref, std::size_t>> allocs;
...@@ -27,13 +25,16 @@ void eliminate_allocation::apply(program& p) const ...@@ -27,13 +25,16 @@ void eliminate_allocation::apply(program& p) const
std::size_t padding = (alignment - (size % alignment)) % alignment; std::size_t padding = (alignment - (size % alignment)) % alignment;
n += size + padding; n += size + padding;
} }
auto mem = p.add_parameter("memory", shape{shape::int8_type, {n}}); if (n > 0)
for(auto&& pp : allocs)
{ {
auto ins = pp.first; auto mem = p.add_parameter("memory", shape{shape::int8_type, {n}});
auto s = ins->get_shape(); for(auto&& pp : allocs)
auto offset = pp.second; {
p.replace_instruction(ins, op::load{s, offset}, mem); auto ins = pp.first;
auto s = ins->get_shape();
auto offset = pp.second;
p.replace_instruction(ins, op::load{s, offset}, mem);
}
} }
} }
......
...@@ -36,14 +36,13 @@ void eliminate_concat::apply(program& p) const ...@@ -36,14 +36,13 @@ void eliminate_concat::apply(program& p) const
// Where are the allocations for the tensors to be concatenated? // Where are the allocations for the tensors to be concatenated?
std::vector<instruction_ref> allocations; std::vector<instruction_ref> allocations;
for(auto ins2 = ins->inputs().begin(); ins2 != ins->inputs().end() - 1; ins2++) std::transform(ins->inputs().begin(), std::prev(ins->inputs().end()), std::back_inserter(allocations), [&](instruction_ref x) {
{ return instruction::get_output_alias(x, true);
auto last2 = (*ins2)->inputs().back(); });
if(last2->name() == concat_opt.allocate())
{ if (std::any_of(allocations.begin(), allocations.end(), [&](auto x) { return x->name() != concat_opt.allocate(); }))
allocations.push_back(last2); continue;
}
}
// Need to sort the allocations, so that we know where to // Need to sort the allocations, so that we know where to
// insert the "super"-allocation // insert the "super"-allocation
std::sort( std::sort(
...@@ -53,13 +52,13 @@ void eliminate_concat::apply(program& p) const ...@@ -53,13 +52,13 @@ void eliminate_concat::apply(program& p) const
// Move "super" allocation to the front // Move "super" allocation to the front
auto first = allocations.front(); auto first = allocations.front();
auto super = p.move_instruction(last, first); auto super = p.move_instruction(last, first);
// Replace each allocation with a load
std::size_t offset = 0; std::size_t offset = 0;
for(auto x : allocations) for(auto alloc : allocations)
{ {
migraphx::op::load op{x->get_shape(), offset}; op::load op{alloc->get_shape(), offset};
// migraphx::op::load op{x->get_shape(), 0}; p.replace_instruction(alloc, op, {super});
p.replace_instruction(x, op, {super}); offset += alloc->get_shape().bytes();
offset += x->get_shape().bytes();
} }
std::vector<instruction_ref> args = {super}; std::vector<instruction_ref> args = {super};
std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args)); std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args));
......
...@@ -73,7 +73,7 @@ struct instruction ...@@ -73,7 +73,7 @@ struct instruction
argument eval() const; argument eval() const;
static instruction_ref get_output_alias(instruction_ref ins); static instruction_ref get_output_alias(instruction_ref ins, bool shallow=false);
private: private:
// internal // internal
......
...@@ -191,11 +191,13 @@ argument instruction::eval() const ...@@ -191,11 +191,13 @@ argument instruction::eval() const
return {}; return {};
} }
instruction_ref instruction::get_output_alias(instruction_ref ins) instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
{ {
auto i = ins->get_operator().output_alias(compute_shapes(ins->inputs())); auto i = ins->get_operator().output_alias(compute_shapes(ins->inputs()));
if(i < 0) if(i < 0)
return ins; return ins;
if(shallow)
return ins->inputs().at(i);
return get_output_alias(ins->inputs().at(i)); return get_output_alias(ins->inputs().at(i));
} }
......
...@@ -14,9 +14,6 @@ namespace gpu { ...@@ -14,9 +14,6 @@ namespace gpu {
void eliminate_workspace::apply(program& p) const void eliminate_workspace::apply(program& p) const
{ {
if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
return;
std::size_t n = 0; std::size_t n = 0;
std::vector<instruction_ref> allocs; std::vector<instruction_ref> allocs;
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
...@@ -32,11 +29,14 @@ void eliminate_workspace::apply(program& p) const ...@@ -32,11 +29,14 @@ void eliminate_workspace::apply(program& p) const
allocs.push_back(ins); allocs.push_back(ins);
} }
} }
auto ws = p.add_parameter("workspace", shape{shape::int8_type, {n}}); if (n > 0)
for(auto&& a : allocs)
{ {
p.replace_instruction(a, ws); auto ws = p.add_parameter("workspace", shape{shape::int8_type, {n}});
p.remove_instruction(a); for(auto&& a : allocs)
{
p.replace_instruction(a, ws);
p.remove_instruction(a);
}
} }
} }
......
...@@ -63,9 +63,9 @@ struct allocate ...@@ -63,9 +63,9 @@ struct allocate
} }
}; };
struct fred_op struct simple_op
{ {
std::string name() const { return "fred_op"; } std::string name() const { return "simple_op"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{ {
migraphx::check_shapes{inputs}.has(1); migraphx::check_shapes{inputs}.has(1);
...@@ -77,44 +77,145 @@ struct fred_op ...@@ -77,44 +77,145 @@ struct fred_op
{ {
return args.at(0); return args.at(0);
} }
int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
}; };
template<class... Ts>
migraphx::shape create_shape(Ts... xs)
{
return migraphx::shape{migraphx::shape::float_type, {std::size_t(xs)...}};
}
using load = migraphx::op::load;
using identity = migraphx::op::identity;
TEST_CASE(simple)
{
auto create_test_program = [] {
migraphx::program p;
auto a1 =
p.add_instruction(allocate{create_shape(1)});
auto p1 = p.add_instruction(simple_op{}, a1);
auto a2 =
p.add_instruction(allocate{create_shape(1)});
auto p2 = p.add_instruction(simple_op{}, a2);
std::size_t axis = 0;
auto a3 = p.add_instruction(
allocate{create_shape(2)});
p.add_instruction(concat(axis), p1, p2, a3);
return p;
};
auto create_control_program = [] {
migraphx::program p;
auto a1 = p.add_instruction(
allocate{create_shape(2)});
auto l1 =
p.add_instruction(load{create_shape(1), 0}, a1);
auto p1 = p.add_instruction(simple_op{}, l1);
auto l2 =
p.add_instruction(load{create_shape(1), 4}, a1);
auto p2 = p.add_instruction(simple_op{}, l2);
p.add_instruction(identity{}, a1, p1, p2);
return p;
};
auto p1 = create_test_program();
auto p2 = create_control_program();
p1.compile(eliminate_concat_target{});
EXPECT(p1 == p2);
}
TEST_CASE(nested)
{
auto concat_test_program = [](auto& p) {
auto a1 =
p.add_instruction(allocate{create_shape(1)});
auto p1 = p.add_instruction(simple_op{}, a1);
auto a2 =
p.add_instruction(allocate{create_shape(1)});
auto p2 = p.add_instruction(simple_op{}, a2);
std::size_t axis = 0;
auto a3 = p.add_instruction(
allocate{create_shape(2)});
return p.add_instruction(concat(axis), p1, p2, a3);
};
auto create_test_program = [&] {
migraphx::program p;
auto concat1 = concat_test_program(p);
auto concat2 = concat_test_program(p);
std::size_t axis = 0;
auto a1 = p.add_instruction(
allocate{create_shape(4)});
p.add_instruction(concat(axis), concat1, concat2, a1);
return p;
};
auto concat_control_program = [](auto& p, auto a1) {
auto l1 =
p.add_instruction(load{create_shape(1), 0}, a1);
auto p1 = p.add_instruction(simple_op{}, l1);
auto l2 =
p.add_instruction(load{create_shape(1), 4}, a1);
auto p2 = p.add_instruction(simple_op{}, l2);
return p.add_instruction(identity{}, a1, p1, p2);
};
auto create_control_program = [&] {
migraphx::program p;
auto a1 = p.add_instruction(
allocate{create_shape(4)});
auto l1 =
p.add_instruction(load{create_shape(2), 0}, a1);
auto concat1 = concat_control_program(p, l1);
auto l2 =
p.add_instruction(load{create_shape(2), 8}, a1);
auto concat2 = concat_control_program(p, l2);
p.add_instruction(identity{}, a1, concat1, concat2);
return p;
};
auto p1 = create_test_program();
auto p2 = create_control_program();
p1.compile(eliminate_concat_target{});
EXPECT(p1 == p2);
}
TEST_CASE(basic) TEST_CASE(basic)
{ {
auto create_test_program = []() { auto create_test_program = [] {
migraphx::program p; migraphx::program p;
auto a1 = auto a1 =
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}}); p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}});
auto p1 = p.add_instruction(fred_op{}, a1); auto p1 = p.add_instruction(simple_op{}, a1);
auto a2 = auto a2 =
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}}); p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}});
auto p2 = p.add_instruction(fred_op{}, a2); auto p2 = p.add_instruction(simple_op{}, a2);
auto a3 = auto a3 =
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}}); p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}});
auto p3 = p.add_instruction(fred_op{}, a3); auto p3 = p.add_instruction(simple_op{}, a3);
std::size_t axis = 1; std::size_t axis = 1;
auto a4 = p.add_instruction( auto a4 = p.add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}}); allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}});
p.add_instruction(concat(axis), p1, p2, p3, a4); p.add_instruction(concat(axis), p1, p2, p3, a4);
return p; return p;
}; };
auto create_control_program = []() { auto create_control_program = [] {
migraphx::program p; migraphx::program p;
auto a1 = p.add_instruction( auto a1 = p.add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}}); allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}});
auto l1 = p.add_instruction( auto l1 = p.add_instruction(
migraphx::op::load{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}, 0}, load{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}, 0},
{a1}); {a1});
auto p1 = p.add_instruction(fred_op{}, l1); auto p1 = p.add_instruction(simple_op{}, l1);
auto l2 = p.add_instruction( auto l2 = p.add_instruction(
migraphx::op::load{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}, 512}, load{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}, 512},
{a1}); {a1});
auto p2 = p.add_instruction(fred_op{}, l2); auto p2 = p.add_instruction(simple_op{}, l2);
auto l3 = p.add_instruction( auto l3 = p.add_instruction(
migraphx::op::load{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}, 1280}, load{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}, 1280},
{a1}); {a1});
auto p3 = p.add_instruction(fred_op{}, l3); auto p3 = p.add_instruction(simple_op{}, l3);
p.add_instruction(migraphx::op::identity{}, {a1, p1, p2, p3}); p.add_instruction(identity{}, {a1, p1, p2, p3});
return p; return p;
}; };
...@@ -127,34 +228,34 @@ TEST_CASE(basic) ...@@ -127,34 +228,34 @@ TEST_CASE(basic)
TEST_CASE(wont_work) TEST_CASE(wont_work)
{ {
auto create_test_program = []() { auto create_test_program = [] {
migraphx::program p; migraphx::program p;
auto a1 = auto a1 =
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}}); p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
auto p1 = p.add_instruction(fred_op{}, a1); auto p1 = p.add_instruction(simple_op{}, a1);
auto a2 = auto a2 =
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}}); p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
auto p2 = p.add_instruction(fred_op{}, a2); auto p2 = p.add_instruction(simple_op{}, a2);
auto a3 = auto a3 =
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}}); p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
auto p3 = p.add_instruction(fred_op{}, a3); auto p3 = p.add_instruction(simple_op{}, a3);
std::size_t axis = 1; std::size_t axis = 1;
auto a4 = p.add_instruction( auto a4 = p.add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}}); allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
p.add_instruction(concat(axis), p1, p2, p3, a4); p.add_instruction(concat(axis), p1, p2, p3, a4);
return p; return p;
}; };
auto create_control_program = []() { auto create_control_program = [] {
migraphx::program p; migraphx::program p;
auto a1 = auto a1 =
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}}); p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
auto p1 = p.add_instruction(fred_op{}, a1); auto p1 = p.add_instruction(simple_op{}, a1);
auto a2 = auto a2 =
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}}); p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
auto p2 = p.add_instruction(fred_op{}, a2); auto p2 = p.add_instruction(simple_op{}, a2);
auto a3 = auto a3 =
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}}); p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
auto p3 = p.add_instruction(fred_op{}, a3); auto p3 = p.add_instruction(simple_op{}, a3);
std::size_t axis = 1; std::size_t axis = 1;
auto a4 = p.add_instruction( auto a4 = p.add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}}); allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment