Commit 15674793 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Debug simplify alg reshapes.

Looks like we're missing a case for squeeze since we can't seem to reshape
without requiring more memory.

TBD
parent 3591803e
......@@ -156,12 +156,20 @@ struct reshape
auto start = idims.begin() + i;
auto it = compute_end_dim(start, idims.end(), rdim);
if(it == start)
{
std::cout << "kabo" << std::endl;
return nullopt;
}
auto n = it - start;
assert((i + n) <= istrides.size());
if(not can_strides_merge(
start, it + 1, istrides.begin() + i, istrides.begin() + i + n + 1))
{
std::cout << "kaboom" << std::endl;
std::cout << "i=" << i << " r=" << r << std::endl;
std::cout << "idim=" << idim << " rdim=" << rdim << std::endl;
return nullopt;
}
i += n;
rstrides.push_back(istrides[i]);
}
......@@ -171,7 +179,10 @@ struct reshape
auto start = rdims.begin() + i;
auto it = compute_end_dim(start, rdims.end(), idim);
if(it == start)
{
std::cout << "kaboomie" << std::endl;
return nullopt;
}
auto n = it - start;
assert((r + n) <= rdims.size());
auto stride = istrides[i] * idim;
......@@ -234,8 +245,10 @@ struct reshape
auto s = reshape_dims(inputs.front(), rdims);
if(not s.has_value())
{
std::cout << inputs.front() << std::endl;
MIGRAPHX_THROW("Reshape on axis that is not packed.");
}
if(s->elements() != inputs.front().elements())
MIGRAPHX_THROW("Reshape: Wrong number of elements for reshape: reshape has " +
std::to_string(s->elements()) + " elements whereas the input has " +
......
......@@ -544,8 +544,15 @@ struct find_inner_broadcast
return 1;
return 3;
}));
// m.debug_print();
auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
op->debug_print();
ins->debug_print();
broadcasts.front()->debug_print();
auto r_op = m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
r_op->debug_print();
std::cout << "\n" << std::endl;
}
};
......@@ -1196,6 +1203,7 @@ struct find_div_const
auto ins = r.result;
auto c_ins = r.instructions["c"];
// m.debug_print();
auto recip = m.insert_instruction(std::next(c_ins), make_op("recip"), c_ins);
auto args = ins->inputs();
......
......@@ -667,6 +667,31 @@ TEST_CASE(simplify_inner_broadcast_different_broadcasts)
EXPECT(m1 == m2);
}
TEST_CASE(simplify_inner_broadcast_input_broadcast)
{
auto mb = migraphx::op::multibroadcast{{2, 576, 1}};
auto mb2 = migraphx::op::multibroadcast{{2, 576, 1}};
migraphx::module m1;
{
auto fl_cons =
m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1}}));
auto fl_cons_b = m1.add_instruction(mb2, fl_cons);
auto y = m1.add_parameter("y", {migraphx::shape::float_type, {2, 576, 768}});
auto yb = m1.add_instruction(mb, y);
auto sum = m1.add_instruction(migraphx::make_op("add"), fl_cons_b, yb);
auto sqrt = m1.add_instruction(migraphx::make_op("sqrt"), sum);
// auto div = m1.add_instruction(migraphx::make_op("div"), sqrt);
m1.add_instruction(pass_op{}, sqrt);
}
run_pass(m1);
run_pass(m1);
run_pass(m1);
m1.debug_print();
EXPECT(true);
}
TEST_CASE(simplify_add_conv1)
{
migraphx::module m;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment