Commit 62eea2df authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into gru_operator

parents 3854a5e1 88f4aad8
...@@ -41,8 +41,9 @@ void dead_code_elimination::apply(program& p) const ...@@ -41,8 +41,9 @@ void dead_code_elimination::apply(program& p) const
// Skip the last instruction // Skip the last instruction
if(i == last) if(i == last)
break; break;
// Skip instruction with empty shape as output unless its a builtin // Skip instruction with empty shape as output unless its a builtin or undefined
if(i->get_shape().elements() == 0 and not(i->name().front() == '@')) if(i->get_shape().elements() == 0 and not(i->name().front() == '@') and
not(i->name() == "undefined"))
continue; continue;
assert(bidistance(p, i, last) > 0); assert(bidistance(p, i, last) > 0);
fix([&](auto self, auto leaf) { fix([&](auto self, auto leaf) {
......
...@@ -5,6 +5,10 @@ ...@@ -5,6 +5,10 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// Forward declare any_cast
template <class T>
const T& any_cast(const T&);
namespace detail { namespace detail {
template <class U> template <class U>
......
...@@ -7,17 +7,17 @@ ...@@ -7,17 +7,17 @@
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/context.hpp>
#include <migraphx/auto_any_cast.hpp> #include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct context;
#ifdef DOXYGEN #ifdef DOXYGEN
/// The operation interface represents an action an instruction will perform. All /// The operation interface represents an action an instruction will perform. All
......
...@@ -358,6 +358,17 @@ struct contiguous ...@@ -358,6 +358,17 @@ struct contiguous
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
return {t, lens}; return {t, lens};
} }
argument compute(const shape& output_shape, std::vector<argument> args) const
{
assert(output_shape.standard());
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
});
});
return result;
}
}; };
struct concat struct concat
......
...@@ -9,8 +9,7 @@ ...@@ -9,8 +9,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// Reshapers that can't handle nonstandard input shapes bool is_reshaper(instruction_ref ins)
bool is_nonstandard_reshaper(instruction_ref ins)
{ {
// clang-format off // clang-format off
static const std::unordered_set<std::string> names = { static const std::unordered_set<std::string> names = {
...@@ -18,57 +17,81 @@ bool is_nonstandard_reshaper(instruction_ref ins) ...@@ -18,57 +17,81 @@ bool is_nonstandard_reshaper(instruction_ref ins)
"contiguous" "contiguous"
}; };
// clang-format on // clang-format on
return contains(names, ins->name()) and ins->inputs().front()->name() == "contiguous"; return contains(names, ins->name());
} }
bool is_reshaper(instruction_ref ins) bool is_transpose_output(instruction_ref ins)
{ {
// clang-format off if(ins->outputs().size() != 1)
static const std::unordered_set<std::string> names = { return false;
"reshape", if(ins->outputs().front()->name() == "contiguous")
"transpose", return is_transpose_output(ins->outputs().front());
// "broadcast", return ins->outputs().front()->name() == "transpose";
"contiguous" }
};
// clang-format on instruction_ref find_transpose_input(instruction_ref ins)
return contains(names, ins->name()) and not is_nonstandard_reshaper(ins); {
if(ins->inputs().size() != 1)
return ins;
if(ins->inputs().front()->name() == "contiguous")
return find_transpose_input(ins->inputs().front());
if(ins->inputs().front()->name() == "transpose")
return ins->inputs().front();
return ins;
} }
void simplify_reshapes::apply(program& p) const void simplify_reshapes::apply(program& p) const
{ {
auto end = std::prev(p.end());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
if(not is_reshaper(ins)) if(ins->outputs().empty() and ins != end)
continue;
if(ins->outputs().size() != 1)
continue;
if(is_reshaper(ins->outputs().front()))
continue; continue;
// Gather reshapes if(is_reshaper(ins))
std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()))
{ {
assert(!reshapes.back()->inputs().empty()); if(std::any_of(ins->outputs().begin(), ins->outputs().end(), &is_reshaper))
assert(p.has_instruction(reshapes.back()->inputs().front())); continue;
auto input = reshapes.back()->inputs().front(); // Gather reshapes
reshapes.push_back(input); std::vector<instruction_ref> reshapes{ins};
} while(is_reshaper(reshapes.back()))
{
assert(!reshapes.back()->inputs().empty());
assert(p.has_instruction(reshapes.back()->inputs().front()));
auto input = reshapes.back()->inputs().front();
reshapes.push_back(input);
}
std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()}; std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
for(auto start : iterator_for(reshapes)) for(auto start : iterator_for(reshapes))
{
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
return i->get_shape() == (*start)->get_shape() and i != (*start);
});
if(last != reshapes.rend())
{ {
r = std::make_pair(*start, *last); auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
break; return i->get_shape() == (*start)->get_shape() and i != (*start);
});
if(last != reshapes.rend())
{
r = std::make_pair(*start, *last);
break;
}
}
if(r.first != r.second)
{
p.replace_instruction(r.first, r.second);
} }
} }
if(r.first != r.second) else if(ins->name() == "transpose")
{ {
p.replace_instruction(r.first, r.second); if(is_transpose_output(ins))
continue;
auto x = ins;
auto t = ins;
do
{
x = t;
t = find_transpose_input(x);
} while(x != t and t->name() == "transpose");
if(t == ins or t->name() != "transpose")
continue;
p.replace_instruction(ins, t->inputs().front());
} }
} }
// Replace all reshapes with as_shape // Replace all reshapes with as_shape
......
...@@ -287,14 +287,7 @@ struct cpu_contiguous ...@@ -287,14 +287,7 @@ struct cpu_contiguous
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
assert(output_shape.standard()); return op.compute(output_shape, std::move(args));
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
});
});
return result;
} }
}; };
......
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <test.hpp> #include <test.hpp>
struct dce_target struct dce_target
...@@ -111,4 +112,21 @@ TEST_CASE(depth_test) ...@@ -111,4 +112,21 @@ TEST_CASE(depth_test)
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
} }
TEST_CASE(undefined_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto undef = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end());
p.compile(dce_target{});
EXPECT(std::distance(p.begin(), p.end()) == count - 1);
EXPECT(not p.has_instruction(undef));
auto result = p.eval({});
EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -108,10 +108,27 @@ TEST_CASE(reshape_shape) ...@@ -108,10 +108,27 @@ TEST_CASE(reshape_shape)
expect_shape(output, migraphx::op::reshape{new_shape}, input); expect_shape(output, migraphx::op::reshape{new_shape}, input);
} }
for(auto&& new_shape : std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}}) for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0, 0}, {3, 2, 0}})
{ {
throws_shape(migraphx::op::reshape{new_shape}, input); throws_shape(migraphx::op::reshape{new_shape}, input);
} }
std::vector<std::pair<std::vector<int64_t>, migraphx::shape>> minus1_tests{
{{2, -1, 3}, {migraphx::shape::float_type, {2, 4, 3}}},
{{0, -1, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
{{2, -1, 0}, {migraphx::shape::float_type, {2, 12, 1}}},
{{0, 0, -1}, {migraphx::shape::float_type, {24, 1, 1}}},
{{2, 0, -1}, {migraphx::shape::float_type, {2, 1, 12}}},
{{-1, 2, 3}, {migraphx::shape::float_type, {4, 2, 3}}},
{{-1, 0, 3}, {migraphx::shape::float_type, {8, 1, 3}}},
{{-1, 0, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
{{-1, 3, 0}, {migraphx::shape::float_type, {8, 3, 1}}}};
for(auto& it : minus1_tests)
{
expect_shape(it.second, migraphx::op::reshape{it.first}, input);
}
} }
TEST_CASE(flatten_shape) TEST_CASE(flatten_shape)
......
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/context.hpp>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include "test.hpp" #include "test.hpp"
......
...@@ -27,9 +27,9 @@ TEST_CASE(double_contig) ...@@ -27,9 +27,9 @@ TEST_CASE(double_contig)
p.compile(simplify_reshapes_target{}); p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(p.begin(), p.end()) == 4);
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == get_2x2()); EXPECT(result != get_2x2());
} }
TEST_CASE(double_transpose) TEST_CASE(double_transpose)
...@@ -95,7 +95,6 @@ TEST_CASE(double_transpose_sin_pass) ...@@ -95,7 +95,6 @@ TEST_CASE(double_transpose_sin_pass)
p.compile(simplify_reshapes_target{}); p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
// std::cout << p << std::endl;
// TODO: Fix this // TODO: Fix this
// EXPECT(std::distance(p.begin(), p.end()) == 1); // EXPECT(std::distance(p.begin(), p.end()) == 1);
auto result = p.eval({}); auto result = p.eval({});
...@@ -134,4 +133,36 @@ TEST_CASE(reshape_transpose) ...@@ -134,4 +133,36 @@ TEST_CASE(reshape_transpose)
EXPECT(std::distance(p.begin(), p.end()) == n); EXPECT(std::distance(p.begin(), p.end()) == n);
} }
TEST_CASE(transpose_contiguous)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}};
auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_op{}, c1);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n);
}
TEST_CASE(transpose_double_contiguous)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}};
auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
p.add_instruction(pass_op{}, c2);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
EXPECT(p.has_instruction(t));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -7,17 +7,17 @@ ...@@ -7,17 +7,17 @@
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/context.hpp>
#include <migraphx/auto_any_cast.hpp> #include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct context;
#ifdef DOXYGEN #ifdef DOXYGEN
/// The operation interface represents an action an instruction will perform. All /// The operation interface represents an action an instruction will perform. All
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment