Commit 7647329c authored by Khalique's avatar Khalique
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into lrn

parents 209fd870 88f4aad8
......@@ -41,8 +41,9 @@ void dead_code_elimination::apply(program& p) const
// Skip the last instruction
if(i == last)
break;
// Skip instruction with empty shape as output unless its a builtin
if(i->get_shape().elements() == 0 and not(i->name().front() == '@'))
// Skip instruction with empty shape as output unless its a builtin or undefined
if(i->get_shape().elements() == 0 and not(i->name().front() == '@') and
not(i->name() == "undefined"))
continue;
assert(bidistance(p, i, last) > 0);
fix([&](auto self, auto leaf) {
......
......@@ -5,6 +5,10 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Forward declare any_cast
template <class T>
const T& any_cast(const T&);
namespace detail {
template <class U>
......
......@@ -7,17 +7,17 @@
#include <memory>
#include <type_traits>
#include <utility>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/context.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct context;
#ifdef DOXYGEN
/// The operation interface represents an action an instruction will perform. All
......
......@@ -382,6 +382,17 @@ struct contiguous
auto t = inputs.at(0).type();
return {t, lens};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
assert(output_shape.standard());
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
});
});
return result;
}
};
struct concat
......
......@@ -9,8 +9,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Reshapers that can't handle nonstandard input shapes
bool is_nonstandard_reshaper(instruction_ref ins)
bool is_reshaper(instruction_ref ins)
{
// clang-format off
static const std::unordered_set<std::string> names = {
......@@ -18,31 +17,39 @@ bool is_nonstandard_reshaper(instruction_ref ins)
"contiguous"
};
// clang-format on
return contains(names, ins->name()) and ins->inputs().front()->name() == "contiguous";
return contains(names, ins->name());
}
bool is_reshaper(instruction_ref ins)
bool is_transpose_output(instruction_ref ins)
{
// clang-format off
static const std::unordered_set<std::string> names = {
"reshape",
"transpose",
// "broadcast",
"contiguous"
};
// clang-format on
return contains(names, ins->name()) and not is_nonstandard_reshaper(ins);
if(ins->outputs().size() != 1)
return false;
if(ins->outputs().front()->name() == "contiguous")
return is_transpose_output(ins->outputs().front());
return ins->outputs().front()->name() == "transpose";
}
instruction_ref find_transpose_input(instruction_ref ins)
{
if(ins->inputs().size() != 1)
return ins;
if(ins->inputs().front()->name() == "contiguous")
return find_transpose_input(ins->inputs().front());
if(ins->inputs().front()->name() == "transpose")
return ins->inputs().front();
return ins;
}
void simplify_reshapes::apply(program& p) const
{
auto end = std::prev(p.end());
for(auto ins : iterator_for(p))
{
if(not is_reshaper(ins))
continue;
if(ins->outputs().size() != 1)
if(ins->outputs().empty() and ins != end)
continue;
if(is_reshaper(ins->outputs().front()))
if(is_reshaper(ins))
{
if(std::any_of(ins->outputs().begin(), ins->outputs().end(), &is_reshaper))
continue;
// Gather reshapes
std::vector<instruction_ref> reshapes{ins};
......@@ -71,6 +78,22 @@ void simplify_reshapes::apply(program& p) const
p.replace_instruction(r.first, r.second);
}
}
else if(ins->name() == "transpose")
{
if(is_transpose_output(ins))
continue;
auto x = ins;
auto t = ins;
do
{
x = t;
t = find_transpose_input(x);
} while(x != t and t->name() == "transpose");
if(t == ins or t->name() != "transpose")
continue;
p.replace_instruction(ins, t->inputs().front());
}
}
// Replace all reshapes with as_shape
for(auto ins : iterator_for(p))
{
......
......@@ -324,14 +324,7 @@ struct cpu_contiguous
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
assert(output_shape.standard());
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
});
});
return result;
return op.compute(output_shape, std::move(args));
}
};
......
#include <migraphx/dead_code_elimination.hpp>
#include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <test.hpp>
struct dce_target
......@@ -111,4 +112,21 @@ TEST_CASE(depth_test)
EXPECT(result != migraphx::literal{4});
}
TEST_CASE(undefined_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto undef = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end());
p.compile(dce_target{});
EXPECT(std::distance(p.begin(), p.end()) == count - 1);
EXPECT(not p.has_instruction(undef));
auto result = p.eval({});
EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/operation.hpp>
#include <migraphx/context.hpp>
#include <sstream>
#include <string>
#include "test.hpp"
......
......@@ -27,9 +27,9 @@ TEST_CASE(double_contig)
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2);
EXPECT(std::distance(p.begin(), p.end()) == 4);
auto result = p.eval({});
EXPECT(result == get_2x2());
EXPECT(result != get_2x2());
}
TEST_CASE(double_transpose)
......@@ -95,7 +95,6 @@ TEST_CASE(double_transpose_sin_pass)
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
// std::cout << p << std::endl;
// TODO: Fix this
// EXPECT(std::distance(p.begin(), p.end()) == 1);
auto result = p.eval({});
......@@ -134,4 +133,36 @@ TEST_CASE(reshape_transpose)
EXPECT(std::distance(p.begin(), p.end()) == n);
}
TEST_CASE(transpose_contiguous)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}};
auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_op{}, c1);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n);
}
TEST_CASE(transpose_double_contiguous)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}};
auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
p.add_instruction(pass_op{}, c2);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
EXPECT(p.has_instruction(t));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -7,17 +7,17 @@
#include <memory>
#include <type_traits>
#include <utility>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/context.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct context;
#ifdef DOXYGEN
/// The operation interface represents an action an instruction will perform. All
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment