"docs/zh_CN/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "5ce474bc8dbd841641d930d5e72098fd8d745f44"
Unverified commit 4f447b03, authored by Paul Fultz II, committed by GitHub

Merge branch 'develop' into batchnorm-rewrite

parents d536e5aa db70de8e
@@ -39,6 +39,8 @@ else()
   set(MIGRAPHX_ENABLE_GPU Off CACHE BOOL "")
 endif()
+set(MIGRAPHX_ENABLE_TF Off CACHE BOOL "")
 add_compile_options(-std=c++14)
 list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
......
@@ -10,6 +10,7 @@ add_library(migraphx
     eliminate_allocation.cpp
     eliminate_contiguous.cpp
     eliminate_concat.cpp
+    eliminate_identity.cpp
     fwd_conv_batchnorm_rewrite.cpp
     rewrite_rnn.cpp
     env.cpp
......
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void eliminate_identity::apply(program& p) const
{
    auto last = std::prev(p.end());
    for(auto ins : iterator_for(p))
    {
        // Skip the first instruction, since we always process the previous
        // instruction
        if(ins == p.begin())
            continue;
        const auto i = std::prev(ins);
        if(i->name() == "identity")
        {
            // Rewire users of the identity to read its input directly, then
            // park the now-dead identity past "last" so it is removed below
            p.replace_instruction(i, i->inputs().front());
            p.move_instruction(i, p.end());
        }
        if(ins == last)
        {
            if(ins->name() == "identity")
            {
                const instruction_ref& identity_input = ins->inputs().front();
                if(identity_input->outputs().size() == 1)
                {
                    p.move_instruction(identity_input, i);
                    // since this is the last instruction, removing it only
                    // requires changing "last" and calling remove below
                    last = std::prev(last);
                }
            }
            break;
        }
    }
    p.remove_instructions(std::next(last), p.end());
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_ELIMINATE_IDENTITY_HPP
#define MIGRAPHX_GUARD_RTGLIB_ELIMINATE_IDENTITY_HPP

#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct program;

/**
 * Remove identity instructions.
 */
struct eliminate_identity
{
    std::string name() const { return "eliminate_identity"; }
    void apply(program& p) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
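For reference, and not part of this commit: a minimal sketch of what the new pass does to a tiny program, assuming the usual migraphx program-building API and running the pass directly instead of through a target's pass list (op::add and the literal values are only illustrative):

#include <migraphx/eliminate_identity.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>

void eliminate_identity_sketch()
{
    migraphx::program p;
    auto one = p.add_literal(1);
    // identity only forwards its input; the pass rewires its users to "one"
    auto id  = p.add_instruction(migraphx::op::identity{}, one);
    p.add_instruction(migraphx::op::add{}, id, id);

    migraphx::eliminate_identity{}.apply(p);
    // after the pass no "identity" instruction remains and the add reads the
    // literal directly, as the test file later in this diff also checks
}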
@@ -7,24 +7,6 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
-struct unknown
-{
-    std::string op;
-    std::string name() const { return "unknown:" + op; }
-    shape compute_shape(std::vector<shape> input) const
-    {
-        if(input.empty())
-            return {};
-        else
-            return input.front();
-    }
-    friend std::ostream& operator<<(std::ostream& os, const unknown& x)
-    {
-        os << x.name();
-        return os;
-    }
-};
 /// Create a program from an onnx file
 program parse_onnx(const std::string& name);
......
@@ -1369,6 +1369,25 @@ struct undefined
     argument compute(const shape&, const std::vector<argument>&) const { return {{}, nullptr}; }
 };
+struct unknown
+{
+    std::string op;
+    std::string name() const { return "unknown:" + op; }
+    shape compute_shape(std::vector<shape> input) const
+    {
+        if(input.empty())
+            return {};
+        else
+            return input.front();
+    }
+    friend std::ostream& operator<<(std::ostream& os, const unknown& x)
+    {
+        os << x.name();
+        return os;
+    }
+};
 } // namespace op
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
......
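For context, and only as an illustration of the struct shown above: op::unknown keeps the original operator name and forwards the first input shape unchanged, so parsing and shape inference can continue past operators the ONNX/TF parsers do not recognize. A hedged sketch (the operator name, shape, and assert-style checks are made up):

#include <cassert>
#include <migraphx/operators.hpp>
#include <migraphx/shape.hpp>

void unknown_op_sketch()
{
    migraphx::op::unknown u{"SomeCustomOp"};
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    // name() is prefixed with "unknown:" and the input shape passes through
    assert(u.name() == "unknown:SomeCustomOp");
    assert(u.compute_shape({s}) == s);
}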
@@ -17,6 +17,7 @@ struct program;
 struct schedule
 {
     schedule_model model{};
+    bool enable = true;
     std::string name() const { return "schedule"; }
     void apply(program& p) const;
 };
......
@@ -7,25 +7,7 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
-struct unknown
-{
-    std::string op;
-    std::string name() const { return "unknown:" + op; }
-    shape compute_shape(std::vector<shape> input) const
-    {
-        if(input.empty())
-            return {};
-        else
-            return input.front();
-    }
-    friend std::ostream& operator<<(std::ostream& os, const unknown& x)
-    {
-        os << x.name();
-        return os;
-    }
-};
-/// Create a program from an onnx file
+/// Create a program from a tf pb file (default is nhwc format)
 program parse_tf(const std::string& name, bool is_nhwc);
 } // namespace MIGRAPHX_INLINE_NS
......
@@ -1206,7 +1206,7 @@ struct onnx_parser
         std::vector<instruction_ref> result;
         if(ops.count(node.op_type()) == 0)
         {
-            result.push_back(prog.add_instruction(unknown{node.op_type()}, args));
+            result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
         }
         else
         {
......
@@ -12,7 +12,12 @@ if(MIGRAPHX_ENABLE_PYTHON)
     C_VISIBILITY_PRESET hidden
     CXX_VISIBILITY_PRESET hidden
   )
-  target_link_libraries(migraphx_py PRIVATE migraphx migraphx_onnx migraphx_cpu)
+  if(MIGRAPHX_ENABLE_TF)
+    target_link_libraries(migraphx_py PRIVATE migraphx migraphx_tf migraphx_cpu)
+    target_compile_definitions(migraphx_py PRIVATE -DENABLE_TF)
+  else()
+    target_link_libraries(migraphx_py PRIVATE migraphx migraphx_onnx migraphx_cpu)
+  endif()
   if(MIGRAPHX_ENABLE_GPU)
     target_link_libraries(migraphx_py PRIVATE migraphx_gpu)
     target_compile_definitions(migraphx_py PRIVATE -DHAVE_GPU)
......
@@ -4,8 +4,13 @@
 #include <migraphx/program.hpp>
 #include <migraphx/generate.hpp>
 #include <migraphx/cpu/target.hpp>
-#include <migraphx/onnx.hpp>
 #include <migraphx/stringutils.hpp>
+#ifdef ENABLE_TF
+#include <migraphx/tf.hpp>
+#else
+#include <migraphx/onnx.hpp>
+#endif
 #ifdef HAVE_GPU
 #include <migraphx/gpu/target.hpp>
 #include <migraphx/gpu/hip.hpp>
@@ -155,8 +160,16 @@ PYBIND11_MODULE(migraphx, m)
         .def("__ne__", std::not_equal_to<migraphx::program>{})
         .def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });
+#ifdef ENABLE_TF
+    m.def("parse_tf",
+          &migraphx::parse_tf,
+          "Parse tf protobuf (default format is nhwc)",
+          py::arg("filename"),
+          py::arg("is_nhwc") = true);
+#else
     m.def("parse_onnx", &migraphx::parse_onnx);
+#endif
     m.def("get_target", [](const std::string& name) -> migraphx::target {
         if(name == "cpu")
             return migraphx::cpu::target{};
......
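For context, and purely illustrative rather than part of the diff: when the tree is built with MIGRAPHX_ENABLE_TF, the same parse_tf entry point the Python binding above wraps is also callable from C++ with the signature declared in tf.hpp; the file name below is a placeholder, and is_nhwc mirrors the py::arg default of true:

#include <migraphx/program.hpp>
#include <migraphx/tf.hpp>

migraphx::program load_frozen_graph_sketch()
{
    // parse a frozen TF graph assuming NHWC layout (the binding's default)
    return migraphx::parse_tf("frozen_graph.pb", /*is_nhwc=*/true);
}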
@@ -341,6 +341,8 @@ struct stream_info
 void schedule::apply(program& p) const
 {
+    if(not enable)
+        return;
     stream_info si;
     auto last = std::prev(p.end());
     si.accumulate_weights(last, model);
......
@@ -17,6 +17,7 @@
 #include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
 #include <migraphx/rewrite_rnn.hpp>
 #include <migraphx/eliminate_concat.hpp>
+#include <migraphx/eliminate_identity.hpp>
 #include <migraphx/gpu/concat_gpu_opt.hpp>
 #include <migraphx/gpu/schedule_model.hpp>
 #include <migraphx/schedule.hpp>
@@ -25,6 +26,8 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_SCHEDULE_PASS)
 std::vector<pass> target::get_passes(migraphx::context& gctx) const
 {
     auto& ctx = any_cast<context>(gctx);
@@ -32,6 +35,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
     return
     {
         dead_code_elimination{},
+        eliminate_identity{},
         fwd_conv_batchnorm_rewrite{},
         dead_code_elimination{},
         rewrite_rnn{},
@@ -53,13 +57,14 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
         fuse_ops{&ctx},
         dead_code_elimination{},
         write_literals{&ctx},
-        schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}},
+        schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, enabled(MIGRAPHX_ENABLE_SCHEDULE_PASS{})},
         memory_coloring{"hip::allocate"},
         dead_code_elimination{},
         eliminate_workspace{},
         eliminate_allocation{"hip::allocate"},
         check_context<context>{},
-        dead_code_elimination{}
+        dead_code_elimination{},
+        eliminate_identity{}
     };
     // clang-format on
 }
......
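For context, a hedged sketch that is not part of the commit: because schedule now carries a bool enable member (defaulting to true, see the schedule.hpp hunk above) and apply returns immediately when it is false, the pass can be constructed disabled exactly as the enabled(MIGRAPHX_ENABLE_SCHEDULE_PASS{}) expression does in the pass list; the stream count of 4 below is only illustrative:

#include <migraphx/program.hpp>
#include <migraphx/schedule.hpp>
#include <migraphx/gpu/schedule_model.hpp>

void schedule_flag_sketch(migraphx::program& p)
{
    // second member is the new enable flag; false makes apply() a no-op
    migraphx::schedule s{migraphx::gpu::schedule_model{4}, false};
    s.apply(p); // leaves p untouched because enable == false
}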
@@ -525,7 +525,7 @@ struct tf_parser
         }
         if(ops.count(node.op()) == 0)
         {
-            instructions[name] = prog.add_instruction(unknown{node.op()}, args);
+            instructions[name] = prog.add_instruction(op::unknown{node.op()}, args);
         }
         else
         {
......
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <test.hpp>

struct eliminate_identity_target
{
    std::string name() const { return "eliminate_identity"; }
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
    {
        return {migraphx::eliminate_identity{}};
    }
    migraphx::context get_context() const { return {}; }
};

TEST_CASE(simple_test)
{
    migraphx::program p;

    auto one          = p.add_literal(1);
    auto one_identity = p.add_instruction(migraphx::op::identity{}, one);
    auto two          = p.add_literal(2);
    auto two_identity = p.add_instruction(migraphx::op::identity{}, two);
    p.add_instruction(sum_op{}, one_identity, two_identity);

    p.compile(eliminate_identity_target{});
    EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
        return ins.name() == "identity";
    }));
    auto result = p.eval({});
    EXPECT(result == migraphx::literal{3});
}

TEST_CASE(simple_test_end)
{
    migraphx::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto ans = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(migraphx::op::identity{}, ans);

    p.compile(eliminate_identity_target{});
    EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
        return ins.name() == "identity";
    }));
    auto result = p.eval({});
    EXPECT(result == migraphx::literal{3});
}

TEST_CASE(simple_test_end_dependency)
{
    migraphx::program p;

    auto one   = p.add_literal(1.0);
    auto two   = p.add_literal(2.0);
    auto three = p.add_literal(3.0);
    auto ans   = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(sum_op{}, ans, three);
    p.add_instruction(migraphx::op::identity{}, ans);

    p.compile(eliminate_identity_target{});
    // The trailing identity is kept here: its input is also consumed by
    // another instruction, so the pass cannot simply promote the input.
    EXPECT(!std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
        return ins.name() == "identity";
    }));
    auto result = p.eval({});
    EXPECT(result == migraphx::literal{3.0});
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -392,8 +392,8 @@ TEST_CASE(unknown_test)
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
     auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
-    auto l2 = p.add_instruction(migraphx::unknown{"Unknown"}, l0, l1);
-    p.add_instruction(migraphx::unknown{"Unknown"}, l2);
+    auto l2 = p.add_instruction(migraphx::op::unknown{"Unknown"}, l0, l1);
+    p.add_instruction(migraphx::op::unknown{"Unknown"}, l2);
     auto prog = migraphx::parse_onnx("unknown_test.onnx");
     EXPECT(p == prog);
......