Commit 5d4fcb59 authored by Khalique's avatar Khalique
Browse files

manual merge

parents 44774583 301b7605
......@@ -638,6 +638,29 @@ struct pad
}
};
struct as_shape
{
shape s;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.s, "shape"));
}
std::string name() const { return "as_shape"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
assert(inputs.front().elements() == s.elements());
return s;
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.front().data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct dot
{
float alpha = 1.0;
......
......@@ -9,7 +9,18 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
bool is_reshaper(const std::string& name)
// Reshapers that can't handle nonstandard input shapes
bool is_nonstandard_reshaper(instruction_ref ins)
{
// clang-format off
static const std::unordered_set<std::string> names = {
"reshape"
};
// clang-format on
return contains(names, ins->name()) and ins->inputs().front()->name() == "contiguous";
}
bool is_reshaper(instruction_ref ins)
{
// clang-format off
static const std::unordered_set<std::string> names = {
......@@ -19,26 +30,27 @@ bool is_reshaper(const std::string& name)
"contiguous"
};
// clang-format on
return contains(names, name);
return contains(names, ins->name()) and not is_nonstandard_reshaper(ins);
}
void simplify_reshapes::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
if(not is_reshaper(ins->name()))
if(not is_reshaper(ins))
continue;
if(ins->outputs().size() != 1)
continue;
if(is_reshaper(ins->outputs().front()->name()))
if(is_reshaper(ins->outputs().front()))
continue;
// Gather reshapes
std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()->name()))
while(is_reshaper(reshapes.back()))
{
assert(!reshapes.back()->inputs().empty());
assert(p.has_instruction(reshapes.back()->inputs().front()));
reshapes.push_back(reshapes.back()->inputs().front());
auto input = reshapes.back()->inputs().front();
reshapes.push_back(input);
}
std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
......@@ -58,6 +70,13 @@ void simplify_reshapes::apply(program& p) const
p.replace_instruction(r.first, r.second);
}
}
// Replace all reshapes with as_shape
for(auto ins : iterator_for(p))
{
if(ins->name() != "reshape")
continue;
p.replace_instruction(ins, op::as_shape{ins->get_shape()}, ins->inputs());
}
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -117,4 +117,21 @@ TEST_CASE(single_transpose_sin_pass)
EXPECT(result != get_2x2());
}
TEST_CASE(reshape_transpose)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}};
auto x = p.add_parameter("x", s);
auto r1 = p.add_instruction(migraphx::op::reshape{{1, 4, 28, 56, 56}}, x);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3, 4}}, r1);
auto ct = p.add_instruction(migraphx::op::contiguous{}, t);
auto r2 = p.add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, ct);
p.add_instruction(pass_op{}, r2);
EXPECT(p.get_shape() == s);
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == s);
EXPECT(std::distance(p.begin(), p.end()) == n);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment