"include/vscode:/vscode.git/clone" did not exist on "675aa69e45381bf8b179f9faf7db8cf726c0e004"
Commit 23546ab5 authored by Khalique's avatar Khalique
Browse files

add contiguous to flatten

parent eacf042e
...@@ -39,7 +39,7 @@ struct flatten ...@@ -39,7 +39,7 @@ struct flatten
std::string name() const { return "flatten"; } std::string name() const { return "flatten"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1).standard();
auto&& lens = inputs.front().lens(); auto&& lens = inputs.front().lens();
auto x = auto x =
std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{}); std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
......
...@@ -47,7 +47,7 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -47,7 +47,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
bool needs_contiguous(const std::string& op_name) const bool needs_contiguous(const std::string& op_name) const
{ {
return contains({"gather"}, op_name); return contains({"flatten", "gather"}, op_name);
} }
instruction_ref parse(const op_desc& opd, instruction_ref parse(const op_desc& opd,
......
...@@ -131,4 +131,17 @@ TEST_CASE(non_standard_return_input) ...@@ -131,4 +131,17 @@ TEST_CASE(non_standard_return_input)
EXPECT(std::distance(m.begin(), m.end()) == count); EXPECT(std::distance(m.begin(), m.end()) == count);
} }
TEST_CASE(non_standard_flatten_op)
{
migraphx::module m;
auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 6, 6, 6}});
auto t = m.add_instruction(migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {6, 6}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(migraphx::make_op("flatten"), c);
auto count = std::distance(m.begin(), m.end());
run_pass(m);
EXPECT(std::distance(m.begin(), m.end()) == count);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment