"vscode:/vscode.git/clone" did not exist on "44f79c7e0c869fee2f15705a1ae1216fec0d879e"
Commit db70de8e authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into disable-schedule-pass

parents 0211d91c 1cf81ce3
......@@ -10,6 +10,7 @@ add_library(migraphx
eliminate_allocation.cpp
eliminate_contiguous.cpp
eliminate_concat.cpp
eliminate_identity.cpp
fwd_conv_batchnorm_rewrite.cpp
rewrite_rnn.cpp
env.cpp
......
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void eliminate_identity::apply(program& p) const
{
auto last = std::prev(p.end());
for(auto ins : iterator_for(p))
{
// Skip the first instruction, since we always process the previous
// instruction
if(ins == p.begin())
continue;
const auto i = std::prev(ins);
if(i->name() == "identity")
{
p.replace_instruction(i, i->inputs().front());
p.move_instruction(i, p.end());
}
if(ins == last)
{
if(ins->name() == "identity")
{
const instruction_ref& identity_input = ins->inputs().front();
if(identity_input->outputs().size() == 1)
{
p.move_instruction(identity_input, i);
// since this is the last instruction, removing it only
// requires changing "last" and calling remove below
last = std::prev(last);
}
}
break;
}
}
p.remove_instructions(std::next(last), p.end());
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_ELIMINATE_IDENTITY_HPP
#define MIGRAPHX_GUARD_RTGLIB_ELIMINATE_IDENTITY_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
/**
* Remove identity instructions.
*/
struct eliminate_identity
{
std::string name() const { return "eliminate_identity"; }
void apply(program& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -17,6 +17,7 @@
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/schedule.hpp>
......@@ -34,6 +35,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
return
{
dead_code_elimination{},
eliminate_identity{},
fwd_conv_batchnorm_rewrite{},
dead_code_elimination{},
rewrite_rnn{},
......@@ -61,7 +63,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
eliminate_workspace{},
eliminate_allocation{"hip::allocate"},
check_context<context>{},
dead_code_elimination{}
dead_code_elimination{},
eliminate_identity{}
};
// clang-format on
}
......
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <test.hpp>
struct eliminate_identity_target
{
std::string name() const { return "eliminate_identity"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::eliminate_identity{}};
}
migraphx::context get_context() const { return {}; }
};
TEST_CASE(simple_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto one_identity = p.add_instruction(migraphx::op::identity{}, one);
auto two = p.add_literal(2);
auto two_identity = p.add_instruction(migraphx::op::identity{}, two);
p.add_instruction(sum_op{}, one_identity, two_identity);
p.compile(eliminate_identity_target{});
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity";
}));
auto result = p.eval({});
EXPECT(result == migraphx::literal{3});
}
TEST_CASE(simple_test_end)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto ans = p.add_instruction(sum_op{}, one, two);
p.add_instruction(migraphx::op::identity{}, ans);
p.compile(eliminate_identity_target{});
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity";
}));
auto result = p.eval({});
EXPECT(result == migraphx::literal{3});
}
TEST_CASE(simple_test_end_dependency)
{
migraphx::program p;
auto one = p.add_literal(1.0);
auto two = p.add_literal(2.0);
auto three = p.add_literal(3.0);
auto ans = p.add_instruction(sum_op{}, one, two);
p.add_instruction(sum_op{}, ans, three);
p.add_instruction(migraphx::op::identity{}, ans);
p.compile(eliminate_identity_target{});
EXPECT(!std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity";
}));
auto result = p.eval({});
EXPECT(result == migraphx::literal{3.0});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment