Commit 11e155c2 authored by Paul's avatar Paul
Browse files

Merge

parents 8a9c5bce aa7ff911
...@@ -88,6 +88,7 @@ struct cpp_generator_impl ...@@ -88,6 +88,7 @@ struct cpp_generator_impl
std::stringstream fs{}; std::stringstream fs{};
std::size_t function_count = 0; std::size_t function_count = 0;
std::function<std::string(std::string)> fmap = nullptr; std::function<std::string(std::string)> fmap = nullptr;
std::function<std::string(shape)> fresult = nullptr;
std::unordered_map<std::string, std::string> point_op_map = {}; std::unordered_map<std::string, std::string> point_op_map = {};
}; };
cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {} cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {}
...@@ -104,6 +105,8 @@ cpp_generator::~cpp_generator() noexcept = default; ...@@ -104,6 +105,8 @@ cpp_generator::~cpp_generator() noexcept = default;
void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; } void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; }
void cpp_generator::fresult(const std::function<std::string(shape)>& f) { impl->fresult = f; }
void cpp_generator::add_point_op(const std::string& op_name, const std::string& code) void cpp_generator::add_point_op(const std::string& op_name, const std::string& code)
{ {
impl->point_op_map[op_name] = code; impl->point_op_map[op_name] = code;
...@@ -174,7 +177,12 @@ cpp_generator::function cpp_generator::generate_module(const module& m) ...@@ -174,7 +177,12 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
ins->inputs().end(), ins->inputs().end(),
std::back_inserter(args), std::back_inserter(args),
[&](auto i) { return names.at(i); }); [&](auto i) { return names.at(i); });
return this->generate_point_op(ins->get_operator(), args);
auto s = this->generate_point_op(ins->get_operator(), args);
if(impl->fresult)
return impl->fresult(ins->get_shape()) + '(' + s + ')';
else
return s;
}); });
return f; return f;
} }
......
...@@ -9,26 +9,6 @@ ...@@ -9,26 +9,6 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class Range, class Iterator>
std::ptrdiff_t bidistance(const Range& r, Iterator start, Iterator last)
{
auto start_forward = start;
auto start_backwards = start;
std::size_t n = 0;
while(start_forward != last and start_backwards != last)
{
n++;
if(start_forward != r.end())
start_forward++;
if(start_backwards != r.begin())
start_backwards--;
}
if(start_forward == last)
return n;
else
return -n;
}
void dead_code_elimination::apply(program& p) const { p.remove_unused_modules(); } void dead_code_elimination::apply(program& p) const { p.remove_unused_modules(); }
void dead_code_elimination::apply(module& m) const void dead_code_elimination::apply(module& m) const
...@@ -48,17 +28,21 @@ void dead_code_elimination::apply(module& m) const ...@@ -48,17 +28,21 @@ void dead_code_elimination::apply(module& m) const
if(i->get_shape().elements() == 0 and i->name().front() != '@' and if(i->get_shape().elements() == 0 and i->name().front() != '@' and
i->name() != "undefined" and i->name() != "identity") i->name() != "undefined" and i->name() != "identity")
continue; continue;
assert(bidistance(m, i, last) > 0); assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
std::unordered_set<instruction_ref> visited;
fix([&](auto self, auto leaf) { fix([&](auto self, auto leaf) {
if(not m.has_instruction(leaf)) if(not m.has_instruction(leaf))
return; return;
if(leaf->outputs().empty()) if(leaf->outputs().empty())
{ {
// Dont visit inputs twice
if(not visited.insert(leaf).second)
return;
std::unordered_set<instruction_ref> args(leaf->inputs().begin(), std::unordered_set<instruction_ref> args(leaf->inputs().begin(),
leaf->inputs().end()); leaf->inputs().end());
leaf->clear_arguments(); leaf->clear_arguments();
assert(bidistance(m, last, leaf) < 0); assert(std::distance(m.begin(), leaf) < std::distance(m.begin(), last));
assert(leaf != ins); assert(leaf != ins);
if(leaf->name() != "@param") if(leaf->name() != "@param")
m.move_instruction(leaf, m.end()); m.move_instruction(leaf, m.end());
......
...@@ -61,7 +61,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) ...@@ -61,7 +61,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu19; migraphx::op::relu relu19;
auto mx19 = mm->add_instruction(relu19, mx18); auto mx19 = mm->add_instruction(relu19, mx18);
migraphx::op::pooling pooling20; migraphx::op::pooling pooling20;
pooling20.mode = "max"; pooling20.mode = migraphx::op::pooling_mode::max;
pooling20.padding = {0, 0}; pooling20.padding = {0, 0};
pooling20.stride = {2, 2}; pooling20.stride = {2, 2};
pooling20.lengths = {3, 3}; pooling20.lengths = {3, 3};
...@@ -81,7 +81,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) ...@@ -81,7 +81,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu24; migraphx::op::relu relu24;
auto mx24 = mm->add_instruction(relu24, mx23); auto mx24 = mm->add_instruction(relu24, mx23);
migraphx::op::pooling pooling25; migraphx::op::pooling pooling25;
pooling25.mode = "max"; pooling25.mode = migraphx::op::pooling_mode::max;
pooling25.padding = {0, 0}; pooling25.padding = {0, 0};
pooling25.stride = {2, 2}; pooling25.stride = {2, 2};
pooling25.lengths = {3, 3}; pooling25.lengths = {3, 3};
...@@ -129,7 +129,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) ...@@ -129,7 +129,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu37; migraphx::op::relu relu37;
auto mx37 = mm->add_instruction(relu37, mx36); auto mx37 = mm->add_instruction(relu37, mx36);
migraphx::op::pooling pooling38; migraphx::op::pooling pooling38;
pooling38.mode = "max"; pooling38.mode = migraphx::op::pooling_mode::max;
pooling38.padding = {0, 0}; pooling38.padding = {0, 0};
pooling38.stride = {2, 2}; pooling38.stride = {2, 2};
pooling38.lengths = {3, 3}; pooling38.lengths = {3, 3};
......
...@@ -995,7 +995,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -995,7 +995,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu492; migraphx::op::relu relu492;
auto mx492 = mm->add_instruction(relu492, mx491); auto mx492 = mm->add_instruction(relu492, mx491);
migraphx::op::pooling pooling493; migraphx::op::pooling pooling493;
pooling493.mode = "max"; pooling493.mode = migraphx::op::pooling_mode::max;
pooling493.padding = {0, 0}; pooling493.padding = {0, 0};
pooling493.stride = {2, 2}; pooling493.stride = {2, 2};
pooling493.lengths = {3, 3}; pooling493.lengths = {3, 3};
...@@ -1025,7 +1025,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1025,7 +1025,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu499; migraphx::op::relu relu499;
auto mx499 = mm->add_instruction(relu499, mx498); auto mx499 = mm->add_instruction(relu499, mx498);
migraphx::op::pooling pooling500; migraphx::op::pooling pooling500;
pooling500.mode = "max"; pooling500.mode = migraphx::op::pooling_mode::max;
pooling500.padding = {0, 0}; pooling500.padding = {0, 0};
pooling500.stride = {2, 2}; pooling500.stride = {2, 2};
pooling500.lengths = {3, 3}; pooling500.lengths = {3, 3};
...@@ -1103,7 +1103,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1103,7 +1103,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu518; migraphx::op::relu relu518;
auto mx518 = mm->add_instruction(relu518, mx517); auto mx518 = mm->add_instruction(relu518, mx517);
migraphx::op::pooling pooling519; migraphx::op::pooling pooling519;
pooling519.mode = "average"; pooling519.mode = migraphx::op::pooling_mode::average;
pooling519.padding = {1, 1}; pooling519.padding = {1, 1};
pooling519.stride = {1, 1}; pooling519.stride = {1, 1};
pooling519.lengths = {3, 3}; pooling519.lengths = {3, 3};
...@@ -1196,7 +1196,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1196,7 +1196,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu541; migraphx::op::relu relu541;
auto mx541 = mm->add_instruction(relu541, mx540); auto mx541 = mm->add_instruction(relu541, mx540);
migraphx::op::pooling pooling542; migraphx::op::pooling pooling542;
pooling542.mode = "average"; pooling542.mode = migraphx::op::pooling_mode::average;
pooling542.padding = {1, 1}; pooling542.padding = {1, 1};
pooling542.stride = {1, 1}; pooling542.stride = {1, 1};
pooling542.lengths = {3, 3}; pooling542.lengths = {3, 3};
...@@ -1289,7 +1289,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1289,7 +1289,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu564; migraphx::op::relu relu564;
auto mx564 = mm->add_instruction(relu564, mx563); auto mx564 = mm->add_instruction(relu564, mx563);
migraphx::op::pooling pooling565; migraphx::op::pooling pooling565;
pooling565.mode = "average"; pooling565.mode = migraphx::op::pooling_mode::average;
pooling565.padding = {1, 1}; pooling565.padding = {1, 1};
pooling565.stride = {1, 1}; pooling565.stride = {1, 1};
pooling565.lengths = {3, 3}; pooling565.lengths = {3, 3};
...@@ -1358,7 +1358,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1358,7 +1358,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu581; migraphx::op::relu relu581;
auto mx581 = mm->add_instruction(relu581, mx580); auto mx581 = mm->add_instruction(relu581, mx580);
migraphx::op::pooling pooling582; migraphx::op::pooling pooling582;
pooling582.mode = "max"; pooling582.mode = migraphx::op::pooling_mode::max;
pooling582.padding = {0, 0}; pooling582.padding = {0, 0};
pooling582.stride = {2, 2}; pooling582.stride = {2, 2};
pooling582.lengths = {3, 3}; pooling582.lengths = {3, 3};
...@@ -1475,7 +1475,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1475,7 +1475,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu610; migraphx::op::relu relu610;
auto mx610 = mm->add_instruction(relu610, mx609); auto mx610 = mm->add_instruction(relu610, mx609);
migraphx::op::pooling pooling611; migraphx::op::pooling pooling611;
pooling611.mode = "average"; pooling611.mode = migraphx::op::pooling_mode::average;
pooling611.padding = {1, 1}; pooling611.padding = {1, 1};
pooling611.stride = {1, 1}; pooling611.stride = {1, 1};
pooling611.lengths = {3, 3}; pooling611.lengths = {3, 3};
...@@ -1604,7 +1604,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1604,7 +1604,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu642; migraphx::op::relu relu642;
auto mx642 = mm->add_instruction(relu642, mx641); auto mx642 = mm->add_instruction(relu642, mx641);
migraphx::op::pooling pooling643; migraphx::op::pooling pooling643;
pooling643.mode = "average"; pooling643.mode = migraphx::op::pooling_mode::average;
pooling643.padding = {1, 1}; pooling643.padding = {1, 1};
pooling643.stride = {1, 1}; pooling643.stride = {1, 1};
pooling643.lengths = {3, 3}; pooling643.lengths = {3, 3};
...@@ -1733,7 +1733,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1733,7 +1733,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu674; migraphx::op::relu relu674;
auto mx674 = mm->add_instruction(relu674, mx673); auto mx674 = mm->add_instruction(relu674, mx673);
migraphx::op::pooling pooling675; migraphx::op::pooling pooling675;
pooling675.mode = "average"; pooling675.mode = migraphx::op::pooling_mode::average;
pooling675.padding = {1, 1}; pooling675.padding = {1, 1};
pooling675.stride = {1, 1}; pooling675.stride = {1, 1};
pooling675.lengths = {3, 3}; pooling675.lengths = {3, 3};
...@@ -1862,7 +1862,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1862,7 +1862,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu706; migraphx::op::relu relu706;
auto mx706 = mm->add_instruction(relu706, mx705); auto mx706 = mm->add_instruction(relu706, mx705);
migraphx::op::pooling pooling707; migraphx::op::pooling pooling707;
pooling707.mode = "average"; pooling707.mode = migraphx::op::pooling_mode::average;
pooling707.padding = {1, 1}; pooling707.padding = {1, 1};
pooling707.stride = {1, 1}; pooling707.stride = {1, 1};
pooling707.lengths = {3, 3}; pooling707.lengths = {3, 3};
...@@ -1955,7 +1955,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1955,7 +1955,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu729; migraphx::op::relu relu729;
auto mx729 = mm->add_instruction(relu729, mx728); auto mx729 = mm->add_instruction(relu729, mx728);
migraphx::op::pooling pooling730; migraphx::op::pooling pooling730;
pooling730.mode = "max"; pooling730.mode = migraphx::op::pooling_mode::max;
pooling730.padding = {0, 0}; pooling730.padding = {0, 0};
pooling730.stride = {2, 2}; pooling730.stride = {2, 2};
pooling730.lengths = {3, 3}; pooling730.lengths = {3, 3};
...@@ -2066,7 +2066,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -2066,7 +2066,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat757.axis = 1; concat757.axis = 1;
auto mx757 = mm->add_instruction(concat757, mx753, mx756); auto mx757 = mm->add_instruction(concat757, mx753, mx756);
migraphx::op::pooling pooling758; migraphx::op::pooling pooling758;
pooling758.mode = "average"; pooling758.mode = migraphx::op::pooling_mode::average;
pooling758.padding = {1, 1}; pooling758.padding = {1, 1};
pooling758.stride = {1, 1}; pooling758.stride = {1, 1};
pooling758.lengths = {3, 3}; pooling758.lengths = {3, 3};
...@@ -2189,7 +2189,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -2189,7 +2189,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat788.axis = 1; concat788.axis = 1;
auto mx788 = mm->add_instruction(concat788, mx784, mx787); auto mx788 = mm->add_instruction(concat788, mx784, mx787);
migraphx::op::pooling pooling789; migraphx::op::pooling pooling789;
pooling789.mode = "average"; pooling789.mode = migraphx::op::pooling_mode::average;
pooling789.padding = {1, 1}; pooling789.padding = {1, 1};
pooling789.stride = {1, 1}; pooling789.stride = {1, 1};
pooling789.lengths = {3, 3}; pooling789.lengths = {3, 3};
...@@ -2210,7 +2210,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -2210,7 +2210,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat793.axis = 1; concat793.axis = 1;
auto mx793 = mm->add_instruction(concat793, mx765, mx775, mx788, mx792); auto mx793 = mm->add_instruction(concat793, mx765, mx775, mx788, mx792);
migraphx::op::pooling pooling794; migraphx::op::pooling pooling794;
pooling794.mode = "average"; pooling794.mode = migraphx::op::pooling_mode::average;
pooling794.padding = {0, 0}; pooling794.padding = {0, 0};
pooling794.stride = {8, 8}; pooling794.stride = {8, 8};
pooling794.lengths = {8, 8}; pooling794.lengths = {8, 8};
......
...@@ -508,8 +508,10 @@ struct roctx : command<roctx> ...@@ -508,8 +508,10 @@ struct roctx : command<roctx>
struct op : command<op> struct op : command<op>
{ {
bool show_ops = false; bool show_ops = false;
std::string op_name{};
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
ap(op_name, {}, ap.metavar("<MIGraphX operator name>"));
ap(show_ops, ap(show_ops,
{"--list", "-l"}, {"--list", "-l"},
ap.help("List all the operators of MIGraphX"), ap.help("List all the operators of MIGraphX"),
...@@ -522,6 +524,12 @@ struct op : command<op> ...@@ -522,6 +524,12 @@ struct op : command<op>
for(const auto& name : get_operators()) for(const auto& name : get_operators())
std::cout << name << std::endl; std::cout << name << std::endl;
} }
else
{
auto op = load_op(op_name);
std::cout << op_name << ": " << std::endl;
std::cout << to_pretty_json_string(op.to_value()) << std::endl;
}
} }
}; };
......
...@@ -17,7 +17,7 @@ class marker_roctx ...@@ -17,7 +17,7 @@ class marker_roctx
std::function<int(const char*)> sym_roctx_range_push; std::function<int(const char*)> sym_roctx_range_push;
std::function<int()> sym_roctx_range_pop; std::function<int()> sym_roctx_range_pop;
uint64_t range_id; uint64_t range_id = 0;
public: public:
marker_roctx() marker_roctx()
......
...@@ -87,6 +87,6 @@ target get_target(bool gpu) ...@@ -87,6 +87,6 @@ target get_target(bool gpu)
void compile_program(program& p, bool gpu) { p.compile(get_target(gpu)); } void compile_program(program& p, bool gpu) { p.compile(get_target(gpu)); }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
} // namespace migraphx } // namespace migraphx
...@@ -561,7 +561,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size) ...@@ -561,7 +561,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu269; migraphx::op::relu relu269;
auto mx269 = mm->add_instruction(relu269, mx268); auto mx269 = mm->add_instruction(relu269, mx268);
migraphx::op::pooling pooling270; migraphx::op::pooling pooling270;
pooling270.mode = "max"; pooling270.mode = migraphx::op::pooling_mode::max;
pooling270.padding = {1, 1}; pooling270.padding = {1, 1};
pooling270.stride = {2, 2}; pooling270.stride = {2, 2};
pooling270.lengths = {3, 3}; pooling270.lengths = {3, 3};
...@@ -1215,7 +1215,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size) ...@@ -1215,7 +1215,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu438; migraphx::op::relu relu438;
auto mx438 = mm->add_instruction(relu438, mx437); auto mx438 = mm->add_instruction(relu438, mx437);
migraphx::op::pooling pooling439; migraphx::op::pooling pooling439;
pooling439.mode = "average"; pooling439.mode = migraphx::op::pooling_mode::average;
pooling439.padding = {0, 0}; pooling439.padding = {0, 0};
pooling439.stride = {1, 1}; pooling439.stride = {1, 1};
pooling439.lengths = {7, 7}; pooling439.lengths = {7, 7};
......
...@@ -13,13 +13,13 @@ ...@@ -13,13 +13,13 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void eliminate_allocation::apply(module& p) const void eliminate_allocation::apply(module& m) const
{ {
assert(alignment > 0); assert(alignment > 0);
std::size_t n = 0; std::size_t n = 0;
std::vector<std::pair<instruction_ref, std::size_t>> allocs; std::vector<std::pair<instruction_ref, std::size_t>> allocs;
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
if(ins->name() != allocation_op) if(ins->name() != allocation_op)
continue; continue;
...@@ -30,13 +30,13 @@ void eliminate_allocation::apply(module& p) const ...@@ -30,13 +30,13 @@ void eliminate_allocation::apply(module& p) const
} }
if(n > 0) if(n > 0)
{ {
auto mem = p.add_parameter("memory", shape{shape::int8_type, {n}}); auto mem = m.add_parameter("memory", shape{shape::int8_type, {n}});
for(auto&& pp : allocs) for(auto&& pp : allocs)
{ {
auto ins = pp.first; auto ins = pp.first;
auto s = ins->get_shape(); auto s = ins->get_shape();
auto offset = pp.second; auto offset = pp.second;
p.replace_instruction( m.replace_instruction(
ins, make_op("load", {{"shape", to_value(s)}, {"offset", offset}}), mem); ins, make_op("load", {{"shape", to_value(s)}, {"offset", offset}}), mem);
} }
} }
......
...@@ -11,7 +11,7 @@ namespace migraphx { ...@@ -11,7 +11,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class Range> template <class Range>
void cse_range(module& p, Range&& r) void cse_range(module& m, Range&& r)
{ {
std::unordered_multimap<std::string, instruction_ref> instructions; std::unordered_multimap<std::string, instruction_ref> instructions;
std::unordered_set<instruction_ref> processed_ins; std::unordered_set<instruction_ref> processed_ins;
...@@ -30,19 +30,24 @@ void cse_range(module& p, Range&& r) ...@@ -30,19 +30,24 @@ void cse_range(module& p, Range&& r)
continue; continue;
if(*eq != *ins) if(*eq != *ins)
continue; continue;
p.replace_instruction(ins, eq); m.replace_instruction(ins, eq);
processed_ins.emplace(ins); processed_ins.emplace(ins);
auto outputs = eq->outputs(); std::vector<instruction_ref> outputs;
std::copy_if(eq->outputs().begin(),
eq->outputs().end(),
std::back_inserter(outputs),
[&](auto x) { return m.has_instruction(x); });
std::sort(outputs.begin(), outputs.end(), [&](auto x, auto y) { std::sort(outputs.begin(), outputs.end(), [&](auto x, auto y) {
return std::distance(eq, x) < std::distance(eq, y); return std::distance(eq, x) < std::distance(eq, y);
}); });
cse_range(p, outputs); cse_range(m, outputs);
} }
instructions.emplace(ins->name(), ins); instructions.emplace(ins->name(), ins);
} }
} }
void eliminate_common_subexpression::apply(module& p) const { cse_range(p, iterator_for(p)); } void eliminate_common_subexpression::apply(module& m) const { cse_range(m, iterator_for(m)); }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void eliminate_concat::apply(module& p) const void eliminate_concat::apply(module& m) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
// Look for the concat operator // Look for the concat operator
if(ins->name() != concat_opt.name()) if(ins->name() != concat_opt.name())
...@@ -64,22 +64,22 @@ void eliminate_concat::apply(module& p) const ...@@ -64,22 +64,22 @@ void eliminate_concat::apply(module& p) const
std::sort(sorted_allocations.begin(), std::sort(sorted_allocations.begin(),
sorted_allocations.end(), sorted_allocations.end(),
[&](instruction_ref x, instruction_ref y) { [&](instruction_ref x, instruction_ref y) {
return std::distance(p.begin(), x) < std::distance(p.begin(), y); return std::distance(m.begin(), x) < std::distance(m.begin(), y);
}); });
// Move "super" allocation to the front // Move "super" allocation to the front
auto first = sorted_allocations.front(); auto first = sorted_allocations.front();
auto super = p.move_instruction(last, first); auto super = m.move_instruction(last, first);
// Replace each allocation with a load // Replace each allocation with a load
std::size_t offset = 0; std::size_t offset = 0;
for(auto alloc : allocations) for(auto alloc : allocations)
{ {
op::load op{alloc->get_shape(), offset}; op::load op{alloc->get_shape(), offset};
p.replace_instruction(alloc, op, {super}); m.replace_instruction(alloc, op, {super});
offset += alloc->get_shape().bytes(); offset += alloc->get_shape().bytes();
} }
std::vector<instruction_ref> args = {super}; std::vector<instruction_ref> args = {super};
std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args)); std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args));
p.replace_instruction(ins, migraphx::make_op("identity"), args); m.replace_instruction(ins, migraphx::make_op("identity"), args);
} }
} }
} }
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/op/contiguous.hpp> #include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/par_for.hpp>
#include <utility> #include <utility>
namespace migraphx { namespace migraphx {
...@@ -69,38 +70,52 @@ static bool try_compute_shape(instruction_ref ins, ...@@ -69,38 +70,52 @@ static bool try_compute_shape(instruction_ref ins,
return try_compute_shape(ins, inputs, mods); return try_compute_shape(ins, inputs, mods);
} }
void eliminate_contiguous::apply(module& p) const void eliminate_contiguous::apply(module& m) const
{ {
for(auto ins : iterator_for(p)) std::vector<instruction_ref> const_instruction;
for(auto ins : iterator_for(m))
{ {
// return instruction should have inputs with standard shape // return instruction should have inputs with standard shape
if(ins->name() == "@return") if(ins->name() == "@return")
continue; continue;
// Make a copy so we can modify it while we iterate // Make a copy so we can modify it while we iterate
auto args = ins->inputs(); auto args = ins->inputs();
auto new_args = args;
auto mod_args = ins->module_inputs();
for(auto arg : ins->inputs()) for(auto arg : ins->inputs())
{ {
if(arg->name() == op_name) if(arg->name() == op_name)
{ {
auto new_args = args; auto prev = arg->inputs().front();
auto prev = arg->inputs().front();
replace(new_args, arg, prev); replace(new_args, arg, prev);
if(try_compute_shape(ins, new_args, ins->module_inputs())) if(try_compute_shape(ins, new_args, mod_args))
{ {
instruction::replace_argument(ins, arg, prev); instruction::replace_argument(ins, arg, prev);
} }
else if(prev->can_eval()) else if(prev->can_eval())
{ {
auto c = op::contiguous{}; const_instruction.push_back(arg);
auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
auto l = p.add_literal(r.get_shape(), r.data());
p.replace_instruction(arg, l);
} }
} }
} }
} }
// Perform evaluations in parallel
std::vector<argument> literals(const_instruction.size());
par_for(const_instruction.size(), 1, [&](const auto i) {
auto c = op::contiguous{};
auto prev = const_instruction[i]->inputs().front();
literals[i] = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
});
for(size_t i = 0; i < const_instruction.size(); i++)
{
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
m.replace_instruction(const_instruction[i], l);
}
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -10,8 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,8 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS {
void eliminate_data_type::apply(module& m) const void eliminate_data_type::apply(module& m) const
{ {
static const std::vector<std::string> skip_op_names = { static const std::vector<std::string> skip_op_names = {"convert",
"convert", "get_tuple_elem", "if", "loop", "roialign"}; "get_tuple_elem",
"if",
"loop",
"roialign",
"scatternd_add",
"scatternd_mul",
"scatternd_none"};
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if(ins->name()[0] == '@') if(ins->name()[0] == '@')
......
...@@ -8,21 +8,21 @@ ...@@ -8,21 +8,21 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void eliminate_identity::apply(module& p) const void eliminate_identity::apply(module& m) const
{ {
auto last = std::prev(p.end()); auto last = std::prev(m.end());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
// Skip the first instruction, since we always process the previous // Skip the first instruction, since we always process the previous
// instruction // instruction
if(ins == p.begin()) if(ins == m.begin())
continue; continue;
const auto i = std::prev(ins); const auto i = std::prev(ins);
if(i->name() == "identity") if(i->name() == "identity")
{ {
p.replace_instruction(i, i->inputs().front()); m.replace_instruction(i, i->inputs().front());
p.move_instruction(i, p.end()); m.move_instruction(i, m.end());
} }
if(ins == last) if(ins == last)
{ {
...@@ -31,7 +31,7 @@ void eliminate_identity::apply(module& p) const ...@@ -31,7 +31,7 @@ void eliminate_identity::apply(module& p) const
const instruction_ref& identity_input = ins->inputs().front(); const instruction_ref& identity_input = ins->inputs().front();
if(identity_input->outputs().size() == 1) if(identity_input->outputs().size() == 1)
{ {
p.move_instruction(identity_input, i); m.move_instruction(identity_input, i);
// since this is the last instruction, removing it only // since this is the last instruction, removing it only
// requires changing "last" and calling remove below // requires changing "last" and calling remove below
last = std::prev(last); last = std::prev(last);
...@@ -40,7 +40,7 @@ void eliminate_identity::apply(module& p) const ...@@ -40,7 +40,7 @@ void eliminate_identity::apply(module& p) const
break; break;
} }
} }
p.remove_instructions(std::next(last), p.end()); m.remove_instructions(std::next(last), m.end());
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins, ...@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m) static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{ {
auto op = any_cast<op::pooling>(ins->get_operator()); auto op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "average") if(op.mode == op::pooling_mode::average)
{ {
return; return;
} }
......
...@@ -13,7 +13,7 @@ struct adjust_allocation ...@@ -13,7 +13,7 @@ struct adjust_allocation
{ {
allocation_model model; allocation_model model;
std::string name() const { return "adjust_allocation"; } std::string name() const { return "adjust_allocation"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -32,18 +32,22 @@ struct allocation_model ...@@ -32,18 +32,22 @@ struct allocation_model
#else #else
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct allocation_model struct allocation_model
* { {
* std::string name() const; //
* std::string copy() const; std::string name() const;
* operation allocate(const shape& s) const; //
* operation preallocate(const shape& s,std::string id) const; std::string copy() const;
* }; //
* operation allocate(const shape& s) const;
*/ //
operation preallocate(const shape& s, std::string id) const;
};
#else
struct allocation_model struct allocation_model
{ {
...@@ -260,6 +264,7 @@ inline const ValueType& any_cast(const allocation_model& x) ...@@ -260,6 +264,7 @@ inline const ValueType& any_cast(const allocation_model& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
...@@ -16,7 +16,7 @@ struct stream_race ...@@ -16,7 +16,7 @@ struct stream_race
instruction_ref before; instruction_ref before;
}; };
std::vector<stream_race> analyze_streams(const module& p, const stream_model& m); std::vector<stream_race> analyze_streams(const module& m, const stream_model& strmm);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
#include <migraphx/config.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/type_name.hpp>
#include <cassert>
#include <string_view>
#include <typeindex>
#include <type_traits>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct any_ptr
{
any_ptr() = default;
template <class T>
any_ptr(T* p) : ptr(p), ti(typeid(T*)), name(get_name<T*>())
{
}
any_ptr(void* p, std::string_view pname) : ptr(p), name(pname) {}
void* get(std::string_view n) const
{
if(name != n)
MIGRAPHX_THROW("any_ptr: type mismatch: " + std::string{name} +
" != " + std::string{n});
return ptr;
}
template <class T>
T get() const
{
static_assert(std::is_pointer<T>{}, "Must be a pointer");
assert(ptr != nullptr);
if(ti and std::type_index{typeid(T)} != *ti)
MIGRAPHX_THROW("any_ptr: type mismatch: " + std::string{name} + " != " + get_name<T>());
else if(name != get_name<T>())
MIGRAPHX_THROW("any_ptr: type mismatch: " + std::string{name} + " != " + get_name<T>());
return reinterpret_cast<T>(ptr);
}
void* unsafe_get() const { return ptr; }
private:
void* ptr = nullptr;
optional<std::type_index> ti = nullopt;
std::string_view name = "";
template <class T>
static const std::string& get_name()
{
return get_type_name<std::remove_cv_t<std::remove_pointer_t<T>>>();
}
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
...@@ -13,7 +13,7 @@ struct module; ...@@ -13,7 +13,7 @@ struct module;
struct auto_contiguous struct auto_contiguous
{ {
std::string name() const { return "auto_contiguous"; } std::string name() const { return "auto_contiguous"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment