"mmdet3d/models/vscode:/vscode.git/clone" did not exist on "d75836ea43697248542b9ad7425c96db3c227556"
Commit 5d4fcb59 authored by Khalique's avatar Khalique
Browse files

manual merge

parents 44774583 301b7605
...@@ -638,6 +638,29 @@ struct pad ...@@ -638,6 +638,29 @@ struct pad
} }
}; };
struct as_shape
{
shape s;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.s, "shape"));
}
std::string name() const { return "as_shape"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
assert(inputs.front().elements() == s.elements());
return s;
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.front().data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct dot struct dot
{ {
float alpha = 1.0; float alpha = 1.0;
......
...@@ -9,7 +9,18 @@ ...@@ -9,7 +9,18 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
bool is_reshaper(const std::string& name) // Reshapers that can't handle nonstandard input shapes
bool is_nonstandard_reshaper(instruction_ref ins)
{
// clang-format off
static const std::unordered_set<std::string> names = {
"reshape"
};
// clang-format on
return contains(names, ins->name()) and ins->inputs().front()->name() == "contiguous";
}
bool is_reshaper(instruction_ref ins)
{ {
// clang-format off // clang-format off
static const std::unordered_set<std::string> names = { static const std::unordered_set<std::string> names = {
...@@ -19,26 +30,27 @@ bool is_reshaper(const std::string& name) ...@@ -19,26 +30,27 @@ bool is_reshaper(const std::string& name)
"contiguous" "contiguous"
}; };
// clang-format on // clang-format on
return contains(names, name); return contains(names, ins->name()) and not is_nonstandard_reshaper(ins);
} }
void simplify_reshapes::apply(program& p) const void simplify_reshapes::apply(program& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
if(not is_reshaper(ins->name())) if(not is_reshaper(ins))
continue; continue;
if(ins->outputs().size() != 1) if(ins->outputs().size() != 1)
continue; continue;
if(is_reshaper(ins->outputs().front()->name())) if(is_reshaper(ins->outputs().front()))
continue; continue;
// Gather reshapes // Gather reshapes
std::vector<instruction_ref> reshapes{ins}; std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()->name())) while(is_reshaper(reshapes.back()))
{ {
assert(!reshapes.back()->inputs().empty()); assert(!reshapes.back()->inputs().empty());
assert(p.has_instruction(reshapes.back()->inputs().front())); assert(p.has_instruction(reshapes.back()->inputs().front()));
reshapes.push_back(reshapes.back()->inputs().front()); auto input = reshapes.back()->inputs().front();
reshapes.push_back(input);
} }
std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()}; std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
...@@ -58,6 +70,13 @@ void simplify_reshapes::apply(program& p) const ...@@ -58,6 +70,13 @@ void simplify_reshapes::apply(program& p) const
p.replace_instruction(r.first, r.second); p.replace_instruction(r.first, r.second);
} }
} }
// Replace all reshapes with as_shape
for(auto ins : iterator_for(p))
{
if(ins->name() != "reshape")
continue;
p.replace_instruction(ins, op::as_shape{ins->get_shape()}, ins->inputs());
}
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -117,4 +117,21 @@ TEST_CASE(single_transpose_sin_pass) ...@@ -117,4 +117,21 @@ TEST_CASE(single_transpose_sin_pass)
EXPECT(result != get_2x2()); EXPECT(result != get_2x2());
} }
TEST_CASE(reshape_transpose)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}};
auto x = p.add_parameter("x", s);
auto r1 = p.add_instruction(migraphx::op::reshape{{1, 4, 28, 56, 56}}, x);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3, 4}}, r1);
auto ct = p.add_instruction(migraphx::op::contiguous{}, t);
auto r2 = p.add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, ct);
p.add_instruction(pass_op{}, r2);
EXPECT(p.get_shape() == s);
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == s);
EXPECT(std::distance(p.begin(), p.end()) == n);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment