Commit 0369e974 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'batch_report' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents 3a474fca d70fd0df
...@@ -155,5 +155,13 @@ std::vector<argument> argument::get_sub_objects() const ...@@ -155,5 +155,13 @@ std::vector<argument> argument::get_sub_objects() const
return result; return result;
} }
argument argument::element(std::size_t i) const
{
assert(this->get_shape().sub_shapes().empty());
auto idx = this->get_shape().index(i);
auto offset = this->get_shape().type_size() * idx;
return argument{shape{this->get_shape().type()}, this->data() + offset};
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
File mode changed from 100755 to 100644
#include <migraphx/decompose.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {
struct alpha_beta
{
float alpha = 0.0;
float beta = 0.0;
};
alpha_beta get_alpha_beta(const operation& op)
{
auto v = op.to_value();
return {v.at("alpha").to<float>(), v.at("beta").to<float>()};
}
struct find_dot_add
{
auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(3)); }
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
auto dot = get_alpha_beta(ins->get_operator());
auto a_ins = ins->inputs()[0];
auto b_ins = ins->inputs()[1];
if(not float_equal(dot.alpha, 1))
{
auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
auto alpha_broadcast = p.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", a_ins->get_shape().lens()}}), alpha);
a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
}
auto dot_ins = p.insert_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
auto c_ins = ins->inputs()[2];
if(not float_equal(dot.beta, 1))
{
auto beta = p.add_literal(literal{shape{c_ins->get_shape().type()}, {dot.beta}});
auto beta_broadcast = p.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), beta);
c_ins = p.insert_instruction(ins, make_op("mul"), c_ins, beta_broadcast);
}
p.replace_instruction(ins, make_op("add"), dot_ins, c_ins);
}
};
struct find_dot_alpha
{
auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(2)); }
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
auto dot = get_alpha_beta(ins->get_operator());
auto a_ins = ins->inputs()[0];
auto b_ins = ins->inputs()[1];
if(not float_equal(dot.alpha, 1))
{
auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
auto alpha_broadcast = p.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", a_ins->get_shape().lens()}}), alpha);
a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
}
p.replace_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
}
};
} // namespace
void decompose::apply(module& p) const { match::find_matches(p, find_dot_add{}, find_dot_alpha{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -6,6 +6,7 @@ add_executable(driver ...@@ -6,6 +6,7 @@ add_executable(driver
resnet50.cpp resnet50.cpp
inceptionv3.cpp inceptionv3.cpp
alexnet.cpp alexnet.cpp
marker_roctx.cpp
) )
set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver) set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver)
# Copy driver for backwards compatibility # Copy driver for backwards compatibility
......
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include "models.hpp" #include "models.hpp"
namespace migraphx { namespace migraphx {
...@@ -144,10 +145,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) ...@@ -144,10 +145,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::multibroadcast multibroadcast42; migraphx::op::multibroadcast multibroadcast42;
multibroadcast42.output_lens = {batch, 4096}; multibroadcast42.output_lens = {batch, 4096};
auto mx42 = mm->add_instruction(multibroadcast42, mx4); auto mx42 = mm->add_instruction(multibroadcast42, mx4);
migraphx::op::dot dot43; float dot43_alpha = 1;
dot43.alpha = 1; float dot43_beta = 1;
dot43.beta = 1; auto mx43 = migraphx::add_apply_alpha_beta(
auto mx43 = mm->add_instruction(dot43, mx40, mx41, mx42); *mm, {mx40, mx41, mx42}, migraphx::make_op("dot"), dot43_alpha, dot43_beta);
migraphx::op::relu relu44; migraphx::op::relu relu44;
auto mx44 = mm->add_instruction(relu44, mx43); auto mx44 = mm->add_instruction(relu44, mx43);
migraphx::op::identity identity45; migraphx::op::identity identity45;
...@@ -158,10 +159,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) ...@@ -158,10 +159,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::multibroadcast multibroadcast47; migraphx::op::multibroadcast multibroadcast47;
multibroadcast47.output_lens = {batch, 4096}; multibroadcast47.output_lens = {batch, 4096};
auto mx47 = mm->add_instruction(multibroadcast47, mx2); auto mx47 = mm->add_instruction(multibroadcast47, mx2);
migraphx::op::dot dot48; float dot48_alpha = 1;
dot48.alpha = 1; float dot48_beta = 1;
dot48.beta = 1; auto mx48 = migraphx::add_apply_alpha_beta(
auto mx48 = mm->add_instruction(dot48, mx45, mx46, mx47); *mm, {mx45, mx46, mx47}, migraphx::make_op("dot"), dot48_alpha, dot48_beta);
migraphx::op::relu relu49; migraphx::op::relu relu49;
auto mx49 = mm->add_instruction(relu49, mx48); auto mx49 = mm->add_instruction(relu49, mx48);
migraphx::op::transpose transpose50; migraphx::op::transpose transpose50;
...@@ -170,10 +171,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) ...@@ -170,10 +171,10 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::multibroadcast multibroadcast51; migraphx::op::multibroadcast multibroadcast51;
multibroadcast51.output_lens = {batch, 1000}; multibroadcast51.output_lens = {batch, 1000};
auto mx51 = mm->add_instruction(multibroadcast51, mx0); auto mx51 = mm->add_instruction(multibroadcast51, mx0);
migraphx::op::dot dot52; float dot52_alpha = 1;
dot52.alpha = 1; float dot52_beta = 1;
dot52.beta = 1; migraphx::add_apply_alpha_beta(
mm->add_instruction(dot52, mx49, mx50, mx51); *mm, {mx49, mx50, mx51}, migraphx::make_op("dot"), dot52_alpha, dot52_beta);
return p; return p;
} }
......
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include "models.hpp" #include "models.hpp"
namespace migraphx { namespace migraphx {
...@@ -2225,10 +2226,10 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -2225,10 +2226,10 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::multibroadcast multibroadcast798; migraphx::op::multibroadcast multibroadcast798;
multibroadcast798.output_lens = {batch, 1000}; multibroadcast798.output_lens = {batch, 1000};
auto mx798 = mm->add_instruction(multibroadcast798, mx0); auto mx798 = mm->add_instruction(multibroadcast798, mx0);
migraphx::op::dot dot799; float dot799_alpha = 1;
dot799.alpha = 1; float dot799_beta = 1;
dot799.beta = 1; migraphx::add_apply_alpha_beta(
mm->add_instruction(dot799, mx796, mx797, mx798); *mm, {mx796, mx797, mx798}, migraphx::make_op("dot"), dot799_alpha, dot799_beta);
return p; return p;
} }
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "verify.hpp" #include "verify.hpp"
#include "perf.hpp" #include "perf.hpp"
#include "models.hpp" #include "models.hpp"
#include "marker_roctx.hpp"
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
...@@ -479,7 +480,24 @@ struct perf : command<perf> ...@@ -479,7 +480,24 @@ struct perf : command<perf>
std::cout << "Allocating params ... " << std::endl; std::cout << "Allocating params ... " << std::endl;
auto m = c.params(p); auto m = c.params(p);
std::cout << "Running performance report ... " << std::endl; std::cout << "Running performance report ... " << std::endl;
p.perf_report(std::cout, n, m); p.perf_report(std::cout, n, m, c.l.batch);
}
};
struct roctx : command<roctx>
{
compiler c;
void parse(argument_parser& ap) { c.parse(ap); }
void run()
{
std::cout << "Compiling ... " << std::endl;
auto p = c.compile();
std::cout << "Allocating params ... " << std::endl;
auto m = c.params(p);
std::cout << "rocTX:\tLoading rocTX library..." << std::endl;
auto rtx = create_marker_roctx();
p.mark(m, std::move(rtx));
} }
}; };
......
#include "marker_roctx.hpp"
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
class marker_roctx
{
std::function<void(const char*)> sym_roctx_mark;
std::function<uint64_t(const char*)> sym_roctx_range_start;
std::function<void(uint64_t)> sym_roctx_range_stop;
std::function<int(const char*)> sym_roctx_range_push;
std::function<int()> sym_roctx_range_pop;
uint64_t range_id;
public:
marker_roctx()
{
dynamic_loader lib = migraphx::dynamic_loader{"libroctx64.so"};
sym_roctx_mark = lib.get_function<void(const char*)>("roctxMarkA");
sym_roctx_range_start = lib.get_function<uint64_t(const char*)>("roctxRangeStartA");
sym_roctx_range_stop = lib.get_function<void(uint64_t)>("roctxRangeStop");
sym_roctx_range_push = lib.get_function<int(const char*)>("roctxRangePushA");
sym_roctx_range_pop = lib.get_function<int()>("roctxRangePop");
sym_roctx_mark("rocTX marker created.");
}
void mark_start(instruction_ref ins_ref)
{
std::string text = "Marker start: " + ins_ref->name();
sym_roctx_range_push(text.c_str());
}
void mark_stop(instruction_ref) { sym_roctx_range_pop(); }
void mark_start(const program&) { range_id = sym_roctx_range_start("0"); }
void mark_stop(const program&) { sym_roctx_range_stop(range_id); }
};
marker create_marker_roctx() { return marker_roctx(); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_MARKER_ROCTX_HPP
#define MIGRAPHX_GUARD_RTGLIB_MARKER_ROCTX_HPP
#include <migraphx/marker.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
marker create_marker_roctx();
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#endif
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include "models.hpp" #include "models.hpp"
namespace migraphx { namespace migraphx {
...@@ -1228,10 +1229,10 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size) ...@@ -1228,10 +1229,10 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::multibroadcast multibroadcast442; migraphx::op::multibroadcast multibroadcast442;
multibroadcast442.output_lens = {batch, 1000}; multibroadcast442.output_lens = {batch, 1000};
auto mx442 = mm->add_instruction(multibroadcast442, mx0); auto mx442 = mm->add_instruction(multibroadcast442, mx0);
migraphx::op::dot dot443; float dot443_alpha = 1;
dot443.alpha = 1; float dot443_beta = 1;
dot443.beta = 1; migraphx::add_apply_alpha_beta(
mm->add_instruction(dot443, mx440, mx441, mx442); *mm, {mx440, mx441, mx442}, migraphx::make_op("dot"), dot443_alpha, dot443_beta);
return p; return p;
} }
......
...@@ -45,6 +45,7 @@ dynamic_loader::dynamic_loader(const std::vector<char>& buffer) ...@@ -45,6 +45,7 @@ dynamic_loader::dynamic_loader(const std::vector<char>& buffer)
std::shared_ptr<void> dynamic_loader::get_symbol(const std::string& name) const std::shared_ptr<void> dynamic_loader::get_symbol(const std::string& name) const
{ {
dlerror();
void* symbol = dlsym(impl->handle.get(), name.c_str()); void* symbol = dlsym(impl->handle.get(), name.c_str());
if(symbol == nullptr) if(symbol == nullptr)
MIGRAPHX_THROW("Symbol not found: " + name); MIGRAPHX_THROW("Symbol not found: " + name);
......
...@@ -11,7 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,7 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS {
void eliminate_data_type::apply(module& m) const void eliminate_data_type::apply(module& m) const
{ {
static const std::vector<std::string> skip_op_names = { static const std::vector<std::string> skip_op_names = {
"convert", "get_tuple_elem", "if", "loop"}; "convert", "get_tuple_elem", "if", "loop", "roialign"};
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if(ins->name()[0] == '@') if(ins->name()[0] == '@')
......
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <iterator>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static literal get_scalar(instruction_ref ins)
{
const auto& s = ins->get_shape();
if(not(s.elements() == 1 or s.scalar()))
return {};
if(not ins->can_eval())
return {};
auto e = ins->eval();
literal r{};
e.visit_at([&](auto x) { r = literal{x}; });
return r;
}
static void create_pointwise_modules(module_pass_manager& mpm)
{
std::size_t n = 0;
for(auto ins : iterator_for(mpm.get_module()))
{
if(not ins->get_operator().attributes().get("pointwise", false))
continue;
auto* pm = mpm.create_module("pointwise" + std::to_string(n++));
pm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> param_map;
std::vector<instruction_ref> pointwise_inputs;
for(auto input : ins->inputs())
{
if(contains(param_map, input))
continue;
auto scalar = get_scalar(input);
if(scalar.empty())
{
pointwise_inputs.push_back(input);
param_map[input] = pm->add_parameter("x" + std::to_string(param_map.size()),
shape{input->get_shape().type()});
}
else
{
param_map[input] = pm->add_literal(scalar);
}
}
std::vector<instruction_ref> inputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return param_map[input]; });
auto r = pm->add_instruction(ins->get_operator(), inputs);
pm->add_return({r});
mpm.get_module().replace_instruction(ins, make_op("pointwise"), pointwise_inputs, {pm});
}
}
static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
instruction_ref output)
{
module_ref pm = ins->module_inputs().at(0);
module_ref xm = output->module_inputs().at(0);
auto last = std::prev(pm->end());
assert(last->name() == "@return");
assert(last->inputs().size() == 1);
std::vector<instruction_ref> inputs = ins->inputs();
std::unordered_map<instruction_ref, instruction_ref> map_ins;
std::unordered_map<instruction_ref, instruction_ref> input_map;
// Copy inputs to input_map
for(auto i : range(inputs.size()))
{
auto input = inputs[i];
auto param = pm->get_parameter("x" + std::to_string(i));
input_map[input] = param;
}
// Add the new parameter and additional inputs
for(auto i : range(output->inputs().size()))
{
auto input = output->inputs()[i];
auto param = xm->get_parameter("x" + std::to_string(i));
if(input == ins)
{
map_ins[param] = last->inputs().front();
input_map[input] = map_ins[param];
}
// Avoid duplicate paramter inputs
else if(contains(input_map, input))
{
map_ins[param] = input_map[input];
}
else
{
map_ins[param] =
pm->add_parameter("x" + std::to_string(inputs.size()), {input->get_shape().type()});
inputs.push_back(input);
input_map[input] = map_ins[param];
}
}
pm->replace_return(pm->insert_module_instructions(last, xm, map_ins));
return inputs;
}
static bool find_pointwise_modules(module& m)
{
bool changed = false;
for(auto ins : iterator_for(m))
{
if(ins->name() != "pointwise")
continue;
if(ins->outputs().empty())
continue;
auto it = std::find_if(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return i->name() == "pointwise" and i->outputs().size() == 1;
});
if(it == ins->inputs().end())
continue;
auto new_inputs = append_pointwise_module(*it, ins);
m.replace_instruction(*it, (*it)->get_operator(), new_inputs, (*it)->module_inputs());
m.replace_instruction(ins, *it);
m.move_instruction(*it, ins);
changed = true;
}
return changed;
}
void fuse_pointwise::apply(module_pass_manager& mpm) const
{
create_pointwise_modules(mpm);
mpm.run_pass(dead_code_elimination{});
for(int i = 0; i < 8; i++)
{
if(not find_pointwise_modules(mpm.get_module()))
break;
mpm.run_pass(dead_code_elimination{});
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_MIGRAPHX_APPLY_ALPHA_BETA_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_APPLY_ALPHA_BETA_HPP
#include "migraphx/make_op.hpp"
#include "migraphx/normalize_attributes.hpp"
#include "migraphx/operation.hpp"
#include <migraphx/instruction_ref.hpp>
#include <migraphx/module.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
instruction_ref insert_apply_alpha_beta(module& m,
instruction_ref pos,
const std::vector<instruction_ref>& args,
const operation& op,
const literal& alpha,
const literal& beta);
template <typename T = float>
instruction_ref insert_apply_alpha_beta(module& m,
instruction_ref pos,
const std::vector<instruction_ref>& args,
const operation& op,
T alpha = 1,
T beta = 0)
{
return insert_apply_alpha_beta(m, pos, args, op, literal{T{alpha}}, literal{T{beta}});
}
template <typename T = float>
instruction_ref add_apply_alpha_beta(module& m,
const std::vector<instruction_ref>& args,
const operation& op,
T alpha = 1,
T beta = 0)
{
return insert_apply_alpha_beta(m, m.end(), args, op, alpha, beta);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_APPLY_ALPHA_BETA_HPP
...@@ -67,6 +67,9 @@ struct argument : raw_data<argument> ...@@ -67,6 +67,9 @@ struct argument : raw_data<argument>
std::vector<argument> get_sub_objects() const; std::vector<argument> get_sub_objects() const;
/// Return the ith element
argument element(std::size_t i) const;
private: private:
void assign_buffer(std::function<char*()> d); void assign_buffer(std::function<char*()> d);
struct data_t struct data_t
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DECOMPOSE_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
#define MIGRAPHX_GUARD_RTGLIB_DECOMPOSE_HPP #define MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <string>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module_pass_manager;
/** struct fuse_pointwise
* Decompose operators.
*/
struct decompose
{ {
std::string name() const { return "decompose"; } std::string name() const { return "fuse_pointwise"; }
void apply(module& p) const; void apply(module_pass_manager& mpm) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
#endif
...@@ -152,30 +152,4 @@ struct instruction ...@@ -152,30 +152,4 @@ struct instruction
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
namespace std {
template <>
struct hash<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = std::size_t;
result_type operator()(const argument_type& x) const noexcept
{
return std::hash<migraphx::instruction*>{}(&*x);
}
};
template <>
struct equal_to<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = bool;
result_type operator()(const migraphx::instruction_ref& x,
const migraphx::instruction_ref& y) const noexcept
{
return &*x == &*y;
}
};
} // namespace std
#endif #endif
...@@ -11,7 +11,35 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,7 +11,35 @@ inline namespace MIGRAPHX_INLINE_NS {
struct instruction; struct instruction;
using instruction_ref = std::list<instruction>::iterator; using instruction_ref = std::list<instruction>::iterator;
migraphx::instruction* as_address(const instruction_ref& ins) noexcept;
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
namespace std {
template <>
struct hash<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = std::size_t;
result_type operator()(const migraphx::instruction_ref& x) const noexcept
{
return std::hash<migraphx::instruction*>{}(migraphx::as_address(x));
}
};
template <>
struct equal_to<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = bool;
result_type operator()(const migraphx::instruction_ref& x,
const migraphx::instruction_ref& y) const noexcept
{
return migraphx::as_address(x) == migraphx::as_address(y);
}
};
} // namespace std
#endif #endif
#ifndef MIGRAPHX_GUARD_MARKER_HPP
#define MIGRAPHX_GUARD_MARKER_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/program.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#ifdef DOXYGEN
/// Marker is an interface to general marking functions, such as rocTX markers.
#else
/*
* Type-erased interface for:
*
* struct marker
* {
* void mark_start(instruction_ref ins_ref) ;
* void mark_start(const program& prog) ;
* void mark_stop(instruction_ref ins) ;
* void mark_stop(const program& prog) ;
* };
*
*/
struct marker
{
// Constructors
marker() = default;
template <typename PrivateDetailTypeErasedT>
marker(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
std::forward<PrivateDetailTypeErasedT>(value)))
{
}
// Assignment
template <typename PrivateDetailTypeErasedT>
marker& operator=(PrivateDetailTypeErasedT value)
{
using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
else
{
marker rhs(value);
swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
}
return *this;
}
// Cast
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
void mark_start(instruction_ref ins_ref)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().mark_start(ins_ref);
}
void mark_start(const program& prog)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().mark_start(prog);
}
void mark_stop(instruction_ref ins)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().mark_stop(ins);
}
void mark_stop(const program& prog)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().mark_stop(prog);
}
friend bool is_shared(const marker& private_detail_x, const marker& private_detail_y)
{
return private_detail_x.private_detail_te_handle_mem_var ==
private_detail_y.private_detail_te_handle_mem_var;
}
private:
struct private_detail_te_handle_base_type
{
virtual ~private_detail_te_handle_base_type() {}
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual void mark_start(instruction_ref ins_ref) = 0;
virtual void mark_start(const program& prog) = 0;
virtual void mark_stop(instruction_ref ins) = 0;
virtual void mark_stop(const program& prog) = 0;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
nullptr)
: private_detail_te_value(value)
{
}
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
}
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
{
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
}
const std::type_info& type() const override { return typeid(private_detail_te_value); }
void mark_start(instruction_ref ins_ref) override
{
private_detail_te_value.mark_start(ins_ref);
}
void mark_start(const program& prog) override { private_detail_te_value.mark_start(prog); }
void mark_stop(instruction_ref ins) override { private_detail_te_value.mark_stop(ins); }
void mark_stop(const program& prog) override { private_detail_te_value.mark_stop(prog); }
PrivateDetailTypeErasedT private_detail_te_value;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
{
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
{
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
template <typename ValueType>
inline const ValueType* any_cast(const marker* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(marker* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(marker& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
template <typename ValueType>
inline const ValueType& any_cast(const marker& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -96,6 +96,11 @@ struct module ...@@ -96,6 +96,11 @@ struct module
instruction_ref move_instruction(instruction_ref src, instruction_ref dst); instruction_ref move_instruction(instruction_ref src, instruction_ref dst);
instruction_ref move_instructions(instruction_ref src, instruction_ref dst); instruction_ref move_instructions(instruction_ref src, instruction_ref dst);
std::vector<instruction_ref>
insert_module_instructions(instruction_ref ins,
module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
template <class... Ts> template <class... Ts>
instruction_ref add_literal(Ts&&... xs) instruction_ref add_literal(Ts&&... xs)
{ {
...@@ -110,6 +115,8 @@ struct module ...@@ -110,6 +115,8 @@ struct module
instruction_ref add_return(std::vector<instruction_ref> args); instruction_ref add_return(std::vector<instruction_ref> args);
instruction_ref replace_return(std::vector<instruction_ref> args);
std::vector<std::string> get_parameter_names() const; std::vector<std::string> get_parameter_names() const;
shape get_parameter_shape(std::string name) const; shape get_parameter_shape(std::string name) const;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment