Commit 23546ab5 authored by Khalique's avatar Khalique
Browse files

add contiguous to flatten

parent eacf042e
......@@ -39,7 +39,7 @@ struct flatten
std::string name() const { return "flatten"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
check_shapes{inputs, *this}.has(1).standard();
auto&& lens = inputs.front().lens();
auto x =
std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
......
......@@ -47,7 +47,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
bool needs_contiguous(const std::string& op_name) const
{
return contains({"gather"}, op_name);
return contains({"flatten", "gather"}, op_name);
}
instruction_ref parse(const op_desc& opd,
......
......@@ -131,4 +131,17 @@ TEST_CASE(non_standard_return_input)
EXPECT(std::distance(m.begin(), m.end()) == count);
}
TEST_CASE(non_standard_flatten_op)
{
migraphx::module m;
auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 6, 6, 6}});
auto t = m.add_instruction(migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {6, 6}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(migraphx::make_op("flatten"), c);
auto count = std::distance(m.begin(), m.end());
run_pass(m);
EXPECT(std::distance(m.begin(), m.end()) == count);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment