Unverified Commit 2466dd6f authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Refactor program to module (#684)



* code backup

* clang format

* change corresponding tool files

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent de10423f
...@@ -10,6 +10,7 @@ namespace migraphx { ...@@ -10,6 +10,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
/** /**
* Schedule instructions for concurrent execution * Schedule instructions for concurrent execution
...@@ -19,7 +20,7 @@ struct schedule ...@@ -19,7 +20,7 @@ struct schedule
schedule_model model{}; schedule_model model{};
bool enable = true; bool enable = true;
std::string name() const { return "schedule"; } std::string name() const { return "schedule"; }
void apply(program& p) const; void apply(module& p) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -16,6 +16,7 @@ namespace migraphx { ...@@ -16,6 +16,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
struct operation; struct operation;
#ifdef DOXYGEN #ifdef DOXYGEN
...@@ -26,11 +27,11 @@ struct schedule_model ...@@ -26,11 +27,11 @@ struct schedule_model
/// Get the number of concurrent instruction allowed /// Get the number of concurrent instruction allowed
std::size_t concurrency() const; std::size_t concurrency() const;
/// Schedule a concurrent instruction /// Schedule a concurrent instruction
void sched(program& p, instruction_ref ins, std::size_t n) const; void sched(module& p, instruction_ref ins, std::size_t n) const;
// Insert necessary waits before an instruction // Insert necessary waits before an instruction
void wait(program& p, instruction_ref ins, std::size_t wait_id) const; void wait(module& p, instruction_ref ins, std::size_t wait_id) const;
// Insert necessary records after an instruction // Insert necessary records after an instruction
void record(program& p, instruction_ref ins, std::size_t wait_id) const; void record(module& p, instruction_ref ins, std::size_t wait_id) const;
/// Compute weights for an operation /// Compute weights for an operation
std::size_t weight(const operation& op) const; std::size_t weight(const operation& op) const;
}; };
...@@ -43,9 +44,9 @@ struct schedule_model ...@@ -43,9 +44,9 @@ struct schedule_model
* struct schedule_model * struct schedule_model
* { * {
* std::size_t concurrency() const; * std::size_t concurrency() const;
* void sched(program& p,instruction_ref ins,std::size_t n) const; * void sched(module& p,instruction_ref ins,std::size_t n) const;
* void wait(program& p,instruction_ref ins,std::size_t wait_id) const; * void wait(module& p,instruction_ref ins,std::size_t wait_id) const;
* void record(program& p,instruction_ref ins,std::size_t wait_id) const; * void record(module& p,instruction_ref ins,std::size_t wait_id) const;
* std::size_t weight(const operation& op) const; * std::size_t weight(const operation& op) const;
* }; * };
* *
...@@ -120,19 +121,19 @@ struct schedule_model ...@@ -120,19 +121,19 @@ struct schedule_model
return (*this).private_detail_te_get_handle().concurrency(); return (*this).private_detail_te_get_handle().concurrency();
} }
void sched(program& p, instruction_ref ins, std::size_t n) const void sched(module& p, instruction_ref ins, std::size_t n) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().sched(p, ins, n); (*this).private_detail_te_get_handle().sched(p, ins, n);
} }
void wait(program& p, instruction_ref ins, std::size_t wait_id) const void wait(module& p, instruction_ref ins, std::size_t wait_id) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().wait(p, ins, wait_id); (*this).private_detail_te_get_handle().wait(p, ins, wait_id);
} }
void record(program& p, instruction_ref ins, std::size_t wait_id) const void record(module& p, instruction_ref ins, std::size_t wait_id) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().record(p, ins, wait_id); (*this).private_detail_te_get_handle().record(p, ins, wait_id);
...@@ -158,11 +159,11 @@ struct schedule_model ...@@ -158,11 +159,11 @@ struct schedule_model
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::size_t concurrency() const = 0; virtual std::size_t concurrency() const = 0;
virtual void sched(program& p, instruction_ref ins, std::size_t n) const = 0; virtual void sched(module& p, instruction_ref ins, std::size_t n) const = 0;
virtual void wait(program& p, instruction_ref ins, std::size_t wait_id) const = 0; virtual void wait(module& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual void record(program& p, instruction_ref ins, std::size_t wait_id) const = 0; virtual void record(module& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual std::size_t weight(const operation& op) const = 0; virtual std::size_t weight(const operation& op) const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -195,19 +196,19 @@ struct schedule_model ...@@ -195,19 +196,19 @@ struct schedule_model
std::size_t concurrency() const override { return private_detail_te_value.concurrency(); } std::size_t concurrency() const override { return private_detail_te_value.concurrency(); }
void sched(program& p, instruction_ref ins, std::size_t n) const override void sched(module& p, instruction_ref ins, std::size_t n) const override
{ {
private_detail_te_value.sched(p, ins, n); private_detail_te_value.sched(p, ins, n);
} }
void wait(program& p, instruction_ref ins, std::size_t wait_id) const override void wait(module& p, instruction_ref ins, std::size_t wait_id) const override
{ {
private_detail_te_value.wait(p, ins, wait_id); private_detail_te_value.wait(p, ins, wait_id);
} }
void record(program& p, instruction_ref ins, std::size_t wait_id) const override void record(module& p, instruction_ref ins, std::size_t wait_id) const override
{ {
private_detail_te_value.record(p, ins, wait_id); private_detail_te_value.record(p, ins, wait_id);
......
...@@ -8,6 +8,7 @@ namespace migraphx { ...@@ -8,6 +8,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
/** /**
* Simplify many algebraic instructions to more efficient versions. * Simplify many algebraic instructions to more efficient versions.
...@@ -15,7 +16,7 @@ struct program; ...@@ -15,7 +16,7 @@ struct program;
struct simplify_algebra struct simplify_algebra
{ {
std::string name() const { return "simplify_algebra"; } std::string name() const { return "simplify_algebra"; }
void apply(program& p) const; void apply(module& p) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -9,6 +9,7 @@ namespace migraphx { ...@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
/** /**
* Eliminate redundant reshapes. * Eliminate redundant reshapes.
...@@ -16,7 +17,7 @@ struct program; ...@@ -16,7 +17,7 @@ struct program;
struct simplify_reshapes struct simplify_reshapes
{ {
std::string name() const { return "simplify_reshapes"; } std::string name() const { return "simplify_reshapes"; }
void apply(program& p) const; void apply(module& p) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -61,7 +61,7 @@ int main(int argc, char const* argv[]) ...@@ -61,7 +61,7 @@ int main(int argc, char const* argv[])
{ {
// GPU target // GPU target
prog.compile(migraphx::gpu::target{}); prog.compile(migraphx::gpu::target{});
migraphx::program::parameter_map m; migraphx::parameter_map m;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}};
for(auto&& x : prog.get_parameter_shapes()) for(auto&& x : prog.get_parameter_shapes())
{ {
......
...@@ -124,7 +124,7 @@ int main(int argc, char const* argv[]) ...@@ -124,7 +124,7 @@ int main(int argc, char const* argv[])
auto s = migraphx::shape{migraphx::shape::float_type, {1, 1, 28, 28}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 1, 28, 28}};
std::cout << s << std::endl; std::cout << s << std::endl;
auto* ptr = input.data(); auto* ptr = input.data();
migraphx::program::parameter_map m; migraphx::parameter_map m;
m["output"] = m["output"] =
migraphx::gpu::to_gpu(migraphx::generate_argument(prog.get_parameter_shape("output"))); migraphx::gpu::to_gpu(migraphx::generate_argument(prog.get_parameter_shape("output")));
for(int i = 0; i < 20; i++) for(int i = 0; i < 20; i++)
......
This diff is collapsed.
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring::apply(program& p) const void memory_coloring::apply(module& p) const
{ {
if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{})) if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
{ {
......
...@@ -67,7 +67,7 @@ using interval_ptr = live_interval*; ...@@ -67,7 +67,7 @@ using interval_ptr = live_interval*;
struct memory_coloring_impl struct memory_coloring_impl
{ {
memory_coloring_impl(program* p, std::string alloc_op, bool p_verify) memory_coloring_impl(module* p, std::string alloc_op, bool p_verify)
: p_program(p), allocation_op(std::move(alloc_op)), enable_verify(p_verify) : p_program(p), allocation_op(std::move(alloc_op)), enable_verify(p_verify)
{ {
instr2_live.clear(); instr2_live.clear();
...@@ -145,7 +145,7 @@ struct memory_coloring_impl ...@@ -145,7 +145,7 @@ struct memory_coloring_impl
return (i1->offset > i2->offset); return (i1->offset > i2->offset);
} }
}; };
program* p_program; module* p_program;
std::unordered_map<const instruction*, interval_ptr> instr2_live; std::unordered_map<const instruction*, interval_ptr> instr2_live;
// universe of live intervals. // universe of live intervals.
std::vector<live_interval> live_intervals; std::vector<live_interval> live_intervals;
......
...@@ -15,20 +15,20 @@ ...@@ -15,20 +15,20 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace) void run_passes(module& modl, const std::vector<pass>& passes, tracer trace)
{ {
for(const auto& p : passes) for(const auto& p : passes)
{ {
trace("Pass: ", p.name()); trace("Pass: ", p.name());
p.apply(prog); p.apply(modl);
trace(prog); trace(modl);
#ifndef NDEBUG #ifndef NDEBUG
trace("Validate ..."); trace("Validate ...");
auto invalid = prog.validate(); auto invalid = modl.validate();
if(invalid != prog.end()) if(invalid != modl.end())
{ {
auto index = std::distance(prog.begin(), invalid); auto index = std::distance(modl.begin(), invalid);
MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " + MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " +
std::to_string(index) + ": " + invalid->name()); std::to_string(index) + ": " + invalid->name());
} }
......
...@@ -20,7 +20,7 @@ bool skip_propogate(instruction_ref ins) ...@@ -20,7 +20,7 @@ bool skip_propogate(instruction_ref ins)
return false; return false;
} }
void propagate_constant::apply(program& p) const void propagate_constant::apply(module& p) const
{ {
for(auto i : iterator_for(p)) for(auto i : iterator_for(p))
{ {
......
...@@ -256,8 +256,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -256,8 +256,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::target>(m, "target"); py::class_<migraphx::target>(m, "target");
py::class_<migraphx::module_wrap>(m, "module") py::class_<migraphx::module>(m, "module").def("print", [](const migraphx::module& mm) {
.def("print", [](const migraphx::module_wrap& mm) { std::cout << *mm.prog << std::endl; }); std::cout << mm << std::endl;
});
py::class_<migraphx::program>(m, "program") py::class_<migraphx::program>(m, "program")
.def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); }) .def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); })
...@@ -277,12 +278,12 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -277,12 +278,12 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("fast_math") = true) py::arg("fast_math") = true)
.def("get_main_module", .def("get_main_module",
[](migraphx::program& p) { [](migraphx::program& p) {
auto mm = p.get_main_module(); auto* mm = p.get_main_module();
return migraphx::module_wrap{mm}; return migraphx::module{*mm};
}) })
.def("run", .def("run",
[](migraphx::program& p, py::dict params) { [](migraphx::program& p, py::dict params) {
migraphx::program::parameter_map pm; migraphx::parameter_map pm;
for(auto x : params) for(auto x : params)
{ {
std::string key = x.first.cast<std::string>(); std::string key = x.first.cast<std::string>();
...@@ -389,7 +390,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -389,7 +390,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
&migraphx::quantize_int8, &migraphx::quantize_int8,
py::arg("prog"), py::arg("prog"),
py::arg("t"), py::arg("t"),
py::arg("calibration") = std::vector<migraphx::program::parameter_map>{}, py::arg("calibration") = std::vector<migraphx::parameter_map>{},
py::arg("ins_names") = std::vector<std::string>{"dot", "convolution"}); py::arg("ins_names") = std::vector<std::string>{"dot", "convolution"});
#ifdef HAVE_GPU #ifdef HAVE_GPU
......
...@@ -27,7 +27,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -27,7 +27,7 @@ inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
instruction_ref insert_quant_ins(program& prog, instruction_ref insert_quant_ins(module& modl,
instruction_ref& ins, instruction_ref& ins,
shape::type_t type, shape::type_t type,
std::unordered_map<instruction_ref, instruction_ref>& map_ins, std::unordered_map<instruction_ref, instruction_ref>& map_ins,
...@@ -59,11 +59,11 @@ instruction_ref insert_quant_ins(program& prog, ...@@ -59,11 +59,11 @@ instruction_ref insert_quant_ins(program& prog,
if(scaled_ins->get_shape().type() != shape::float_type) if(scaled_ins->get_shape().type() != shape::float_type)
{ {
float_ins = float_ins =
prog.insert_instruction(insert_loc, op::convert{shape::float_type}, scaled_ins); modl.insert_instruction(insert_loc, op::convert{shape::float_type}, scaled_ins);
} }
std::vector<float> vec_scale(scaled_ins->get_shape().elements(), scale); std::vector<float> vec_scale(scaled_ins->get_shape().elements(), scale);
auto l_scale = prog.add_literal(literal(float_ins->get_shape(), vec_scale)); auto l_scale = modl.add_literal(literal(float_ins->get_shape(), vec_scale));
scaled_ins = prog.insert_instruction(insert_loc, op::mul{}, l_scale, float_ins); scaled_ins = modl.insert_instruction(insert_loc, op::mul{}, l_scale, float_ins);
} }
auto shifted_ins = scaled_ins; auto shifted_ins = scaled_ins;
...@@ -72,27 +72,27 @@ instruction_ref insert_quant_ins(program& prog, ...@@ -72,27 +72,27 @@ instruction_ref insert_quant_ins(program& prog,
auto float_ins = shifted_ins; auto float_ins = shifted_ins;
if(shifted_ins->get_shape().type() != shape::float_type) if(shifted_ins->get_shape().type() != shape::float_type)
{ {
float_ins = prog.insert_instruction( float_ins = modl.insert_instruction(
insert_loc, op::convert{shape::float_type}, shifted_ins); insert_loc, op::convert{shape::float_type}, shifted_ins);
} }
std::vector<float> vec_shift(shifted_ins->get_shape().elements(), shift); std::vector<float> vec_shift(shifted_ins->get_shape().elements(), shift);
auto l_shift = prog.add_literal(literal(float_ins->get_shape(), vec_shift)); auto l_shift = modl.add_literal(literal(float_ins->get_shape(), vec_shift));
shifted_ins = prog.insert_instruction(insert_loc, op::add{}, l_shift, float_ins); shifted_ins = modl.insert_instruction(insert_loc, op::add{}, l_shift, float_ins);
} }
auto rounded_ins = prog.insert_instruction(insert_loc, op::round{}, shifted_ins); auto rounded_ins = modl.insert_instruction(insert_loc, op::round{}, shifted_ins);
auto rounded_lens = rounded_ins->get_shape().lens(); auto rounded_lens = rounded_ins->get_shape().lens();
auto max_clip = prog.add_literal(127.0f); auto max_clip = modl.add_literal(127.0f);
auto min_clip = prog.add_literal(-128.0f); auto min_clip = modl.add_literal(-128.0f);
max_clip = prog.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, max_clip); max_clip = modl.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, max_clip);
min_clip = prog.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, min_clip); min_clip = modl.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, min_clip);
auto clipped_ins = auto clipped_ins =
prog.insert_instruction(insert_loc, op::clip{}, rounded_ins, min_clip, max_clip); modl.insert_instruction(insert_loc, op::clip{}, rounded_ins, min_clip, max_clip);
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, clipped_ins); quant_ins = modl.insert_instruction(insert_loc, op::convert{type}, clipped_ins);
} }
else else
{ {
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, ins); quant_ins = modl.insert_instruction(insert_loc, op::convert{type}, ins);
} }
map_ins[ins] = quant_ins; map_ins[ins] = quant_ins;
...@@ -107,8 +107,9 @@ instruction_ref insert_quant_ins(program& prog, ...@@ -107,8 +107,9 @@ instruction_ref insert_quant_ins(program& prog,
// truncate of the input to get the fp16. // truncate of the input to get the fp16.
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names) void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
{ {
auto* mm = prog.get_main_module();
std::unordered_map<instruction_ref, instruction_ref> map_fp16; std::unordered_map<instruction_ref, instruction_ref> map_fp16;
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(*mm))
{ {
if(ins->name() == "@return") if(ins->name() == "@return")
break; break;
...@@ -139,7 +140,7 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names) ...@@ -139,7 +140,7 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
} }
else else
{ {
input_fp16 = insert_quant_ins(prog, input, shape::half_type, map_fp16); input_fp16 = insert_quant_ins(*mm, input, shape::half_type, map_fp16);
} }
converted_inputs.push_back(input_fp16); converted_inputs.push_back(input_fp16);
} }
...@@ -162,18 +163,18 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names) ...@@ -162,18 +163,18 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
// check the dead code case to avoid assert // check the dead code case to avoid assert
bool output_empty = ins->outputs().empty(); bool output_empty = ins->outputs().empty();
auto ins_orig_type = auto ins_orig_type =
prog.insert_instruction(std::next(ins), op::convert{orig_type}, ins); mm->insert_instruction(std::next(ins), op::convert{orig_type}, ins);
if(!output_empty) if(!output_empty)
{ {
prog.replace_instruction(ins, ins_orig_type); mm->replace_instruction(ins, ins_orig_type);
} }
} }
prog.replace_instruction(ins, op, converted_inputs); mm->replace_instruction(ins, op, converted_inputs);
} }
} }
static void ins_quantize_int8(program& prog, static void ins_quantize_int8(module& modl,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref>& converted_inputs, std::vector<instruction_ref>& converted_inputs,
const std::vector<std::pair<float, float>>& ins_quant_params) const std::vector<std::pair<float, float>>& ins_quant_params)
...@@ -195,14 +196,14 @@ static void ins_quantize_int8(program& prog, ...@@ -195,14 +196,14 @@ static void ins_quantize_int8(program& prog,
int32_t quant_beta = static_cast<int32_t>(std::round(new_beta)); int32_t quant_beta = static_cast<int32_t>(std::round(new_beta));
if(shape::int32_type == orig_type) if(shape::int32_type == orig_type)
{ {
prog.replace_instruction( modl.replace_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs); ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
} }
else else
{ {
auto quant_dot = prog.insert_instruction( auto quant_dot = modl.insert_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs); ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
prog.replace_instruction(ins, op::convert{orig_type}, quant_dot); modl.replace_instruction(ins, op::convert{orig_type}, quant_dot);
} }
} }
// either alpha or beta cannot be quantized because of too big // either alpha or beta cannot be quantized because of too big
...@@ -213,51 +214,51 @@ static void ins_quantize_int8(program& prog, ...@@ -213,51 +214,51 @@ static void ins_quantize_int8(program& prog,
{ {
converted_inputs.pop_back(); converted_inputs.pop_back();
} }
auto q_dot = prog.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs); auto q_dot = modl.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs);
auto f_dot = prog.insert_instruction(ins, op::convert{shape::float_type}, q_dot); auto f_dot = modl.insert_instruction(ins, op::convert{shape::float_type}, q_dot);
auto c_shape = q_dot->get_shape(); auto c_shape = q_dot->get_shape();
std::vector<float> vec_alpha(c_shape.elements(), new_alpha); std::vector<float> vec_alpha(c_shape.elements(), new_alpha);
auto l_alpha = auto l_alpha =
prog.add_literal(literal({shape::float_type, c_shape.lens()}, vec_alpha)); modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_alpha));
if(inputs.size() == 3 and dot_op.beta != 0.0f) if(inputs.size() == 3 and dot_op.beta != 0.0f)
{ {
auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, f_dot); auto alpha_ab = modl.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
std::vector<float> vec_beta(c_shape.elements(), dot_op.beta); std::vector<float> vec_beta(c_shape.elements(), dot_op.beta);
auto l_beta = auto l_beta =
prog.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta)); modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta));
instruction_ref beta_c{}; instruction_ref beta_c{};
if(orig_type != shape::float_type) if(orig_type != shape::float_type)
{ {
auto fp32_c = auto fp32_c =
prog.insert_instruction(ins, op::convert{shape::float_type}, inputs.back()); modl.insert_instruction(ins, op::convert{shape::float_type}, inputs.back());
beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, fp32_c); beta_c = modl.insert_instruction(ins, op::mul{}, l_beta, fp32_c);
} }
else else
{ {
beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back()); beta_c = modl.insert_instruction(ins, op::mul{}, l_beta, inputs.back());
} }
if(orig_type == shape::float_type) if(orig_type == shape::float_type)
{ {
prog.replace_instruction(ins, op::add{}, alpha_ab, beta_c); modl.replace_instruction(ins, op::add{}, alpha_ab, beta_c);
} }
else else
{ {
auto f_res = prog.insert_instruction(ins, op::add{}, alpha_ab, beta_c); auto f_res = modl.insert_instruction(ins, op::add{}, alpha_ab, beta_c);
prog.replace_instruction(ins, op::convert{orig_type}, f_res); modl.replace_instruction(ins, op::convert{orig_type}, f_res);
} }
} }
else else
{ {
if(orig_type == shape::float_type) if(orig_type == shape::float_type)
{ {
prog.replace_instruction(ins, op::mul{}, l_alpha, f_dot); modl.replace_instruction(ins, op::mul{}, l_alpha, f_dot);
} }
else else
{ {
auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, f_dot); auto alpha_ab = modl.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
prog.replace_instruction(ins, op::convert{orig_type}, alpha_ab); modl.replace_instruction(ins, op::convert{orig_type}, alpha_ab);
} }
} }
} }
...@@ -274,7 +275,7 @@ static void ins_quantize_int8(program& prog, ...@@ -274,7 +275,7 @@ static void ins_quantize_int8(program& prog,
auto group = conv_op.group; auto group = conv_op.group;
auto adjust_factor = 1.0f / (ins_quant_params[0].first * ins_quant_params[1].first); auto adjust_factor = 1.0f / (ins_quant_params[0].first * ins_quant_params[1].first);
auto quant_conv = prog.insert_instruction( auto quant_conv = modl.insert_instruction(
ins, ins,
op::quant_convolution{padding, stride, dilation, padding_mode, group}, op::quant_convolution{padding, stride, dilation, padding_mode, group},
converted_inputs); converted_inputs);
...@@ -282,25 +283,25 @@ static void ins_quantize_int8(program& prog, ...@@ -282,25 +283,25 @@ static void ins_quantize_int8(program& prog,
std::vector<float> vec_factor(quant_conv->get_shape().elements(), adjust_factor); std::vector<float> vec_factor(quant_conv->get_shape().elements(), adjust_factor);
if(quant_conv->get_shape().type() == orig_type and adjust_factor >= threshold) if(quant_conv->get_shape().type() == orig_type and adjust_factor >= threshold)
{ {
auto l_factor = prog.add_literal( auto l_factor = modl.add_literal(
literal(quant_conv->get_shape(), vec_factor.begin(), vec_factor.end())); literal(quant_conv->get_shape(), vec_factor.begin(), vec_factor.end()));
prog.replace_instruction(ins, op::mul{}, quant_conv, l_factor); modl.replace_instruction(ins, op::mul{}, quant_conv, l_factor);
} }
// convert quant_conv output to float type, multiply the factor and // convert quant_conv output to float type, multiply the factor and
// conver back to original type // conver back to original type
else else
{ {
auto float_conv = auto float_conv =
prog.insert_instruction(ins, op::convert{shape::float_type}, quant_conv); modl.insert_instruction(ins, op::convert{shape::float_type}, quant_conv);
auto l_factor = prog.add_literal(literal(float_conv->get_shape(), vec_factor)); auto l_factor = modl.add_literal(literal(float_conv->get_shape(), vec_factor));
if(orig_type == shape::float_type) if(orig_type == shape::float_type)
{ {
prog.replace_instruction(ins, op::mul{}, l_factor, float_conv); modl.replace_instruction(ins, op::mul{}, l_factor, float_conv);
} }
else else
{ {
auto adjusted_conv = prog.insert_instruction(ins, op::mul{}, l_factor, float_conv); auto adjusted_conv = modl.insert_instruction(ins, op::mul{}, l_factor, float_conv);
prog.replace_instruction(ins, op::convert{orig_type}, adjusted_conv); modl.replace_instruction(ins, op::convert{orig_type}, adjusted_conv);
} }
} }
} }
...@@ -338,10 +339,11 @@ void quantize_int8_impl(program& prog, ...@@ -338,10 +339,11 @@ void quantize_int8_impl(program& prog,
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation"); MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
} }
auto* mm = prog.get_main_module();
std::size_t quant_param_index = 0; std::size_t quant_param_index = 0;
std::unordered_map<instruction_ref, instruction_ref> map_quant_ins; std::unordered_map<instruction_ref, instruction_ref> map_quant_ins;
std::unordered_map<instruction_ref, std::size_t> map_ins_index; std::unordered_map<instruction_ref, std::size_t> map_ins_index;
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(*mm))
{ {
if(ins->name() == "@return") if(ins->name() == "@return")
break; break;
...@@ -398,7 +400,7 @@ void quantize_int8_impl(program& prog, ...@@ -398,7 +400,7 @@ void quantize_int8_impl(program& prog,
else else
{ {
quant_input = insert_quant_ins( quant_input = insert_quant_ins(
prog, input, quant_type, map_quant_ins, param.first, param.second); *mm, input, quant_type, map_quant_ins, param.first, param.second);
} }
converted_inputs.push_back(quant_input); converted_inputs.push_back(quant_input);
} }
...@@ -414,7 +416,7 @@ void quantize_int8_impl(program& prog, ...@@ -414,7 +416,7 @@ void quantize_int8_impl(program& prog,
continue; continue;
} }
ins_quantize_int8(prog, ins, converted_inputs, ins_quant_params); ins_quantize_int8(*mm, ins, converted_inputs, ins_quant_params);
} }
if(quant_param_index != quant_params.size()) if(quant_param_index != quant_params.size())
...@@ -425,7 +427,7 @@ void quantize_int8_impl(program& prog, ...@@ -425,7 +427,7 @@ void quantize_int8_impl(program& prog,
void quantize_int8(program& prog, void quantize_int8(program& prog,
const target& t, const target& t,
const std::vector<program::parameter_map>& calibration, const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names) const std::vector<std::string>& ins_names)
{ {
// insert capture operator // insert capture operator
...@@ -439,7 +441,7 @@ void quantize_int8(program& prog, ...@@ -439,7 +441,7 @@ void quantize_int8(program& prog,
// quantization scale and shift // quantization scale and shift
for(auto&& arg : calibration) for(auto&& arg : calibration)
{ {
program::parameter_map m; parameter_map m;
for(auto&& x : cap_prog.get_parameter_shapes()) for(auto&& x : cap_prog.get_parameter_shapes())
{ {
if(arg.count(x.first) > 0) if(arg.count(x.first) > 0)
...@@ -464,7 +466,7 @@ std::size_t capture_arguments(program& prog, ...@@ -464,7 +466,7 @@ std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names, const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func) const std::function<void(std::size_t, std::vector<argument>)>& func)
{ {
auto* mm = prog.get_main_module();
size_t num_quant_params = 0; size_t num_quant_params = 0;
// the int8 quantization only support dot and convolution // the int8 quantization only support dot and convolution
std::set<std::string> op_names = {"dot", "convolution"}; std::set<std::string> op_names = {"dot", "convolution"};
...@@ -476,7 +478,7 @@ std::size_t capture_arguments(program& prog, ...@@ -476,7 +478,7 @@ std::size_t capture_arguments(program& prog,
} }
std::unordered_map<instruction_ref, instruction_ref> ins_map; std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(*mm))
{ {
if(not contains(ins_names, ins->name())) if(not contains(ins_names, ins->name()))
{ {
...@@ -494,7 +496,7 @@ std::size_t capture_arguments(program& prog, ...@@ -494,7 +496,7 @@ std::size_t capture_arguments(program& prog,
} }
else else
{ {
new_ins = prog.insert_instruction( new_ins = mm->insert_instruction(
std::next(input), op::capture{num_quant_params++, func}, input); std::next(input), op::capture{num_quant_params++, func}, input);
ins_map[input] = new_ins; ins_map[input] = new_ins;
} }
......
...@@ -22,7 +22,7 @@ struct find_dot_add ...@@ -22,7 +22,7 @@ struct find_dot_add
match::name("dot")(match::nargs(2)).bind("dot")))); match::name("dot")(match::nargs(2)).bind("dot"))));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto dot_ins = r.instructions["dot"]; auto dot_ins = r.instructions["dot"];
...@@ -36,7 +36,7 @@ struct find_dot_add ...@@ -36,7 +36,7 @@ struct find_dot_add
}; };
} // namespace } // namespace
void remap::apply(program& p) const { match::find_matches(p, find_dot_add{}); } void remap::apply(module& p) const { match::find_matches(p, find_dot_add{}); }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void rewrite_batchnorm::apply(program& p) const void rewrite_batchnorm::apply(module& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void rewrite_pooling::apply(program& prog) const void rewrite_pooling::apply(module& prog) const
{ {
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(prog))
{ {
......
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void rewrite_rnn::apply(program& prog) const void rewrite_rnn::apply(module& prog) const
{ {
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(prog))
{ {
...@@ -47,13 +47,13 @@ void rewrite_rnn::apply(program& prog) const ...@@ -47,13 +47,13 @@ void rewrite_rnn::apply(program& prog) const
} }
} }
void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
{ {
assert(ins->name() == "rnn"); assert(ins->name() == "rnn");
// could be 3 to 6 inputs, but the parse_rnn function will // could be 3 to 6 inputs, but the parse_rnn function will
// append undefined operators to make 6 arguments when parsing // append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is user can have num of arguments // an onnx file. Another case is user can have num of arguments
// when writing their program. // when writing their module.
auto args = ins->inputs(); auto args = ins->inputs();
shape seq_shape = args[0]->get_shape(); shape seq_shape = args[0]->get_shape();
...@@ -210,7 +210,7 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const ...@@ -210,7 +210,7 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
} }
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward, std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
program& prog, module& prog,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
operation& actv_func) const operation& actv_func) const
...@@ -336,7 +336,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) ...@@ -336,7 +336,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins)
} }
} }
void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
{ {
assert(ins->name() == "gru"); assert(ins->name() == "gru");
const auto actv_funcs = gru_actv_funcs(ins); const auto actv_funcs = gru_actv_funcs(ins);
...@@ -502,7 +502,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const ...@@ -502,7 +502,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
} }
std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
program& prog, module& prog,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
int linear_before_reset, int linear_before_reset,
...@@ -685,7 +685,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const ...@@ -685,7 +685,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
} }
// for lstm operators // for lstm operators
void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
{ {
assert(ins->name() == "lstm"); assert(ins->name() == "lstm");
auto args = ins->inputs(); auto args = ins->inputs();
...@@ -927,7 +927,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -927,7 +927,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
} }
std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
program& prog, module& prog,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
const operation& actv_func1, const operation& actv_func1,
...@@ -1158,7 +1158,7 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const ...@@ -1158,7 +1158,7 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
} }
} }
bool rewrite_rnn::is_variable_seq_lens(const program& prog, instruction_ref seq_lens) const bool rewrite_rnn::is_variable_seq_lens(const module& prog, instruction_ref seq_lens) const
{ {
bool is_var_lens = false; bool is_var_lens = false;
if(seq_lens != prog.end()) if(seq_lens != prog.end())
...@@ -1188,7 +1188,7 @@ bool rewrite_rnn::is_variable_seq_lens(const program& prog, instruction_ref seq_ ...@@ -1188,7 +1188,7 @@ bool rewrite_rnn::is_variable_seq_lens(const program& prog, instruction_ref seq_
} }
std::size_t std::size_t
rewrite_rnn::get_seq_len(const program& prog, instruction_ref input, instruction_ref seq_lens) const rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const
{ {
bool is_var_lens = is_variable_seq_lens(prog, seq_lens); bool is_var_lens = is_variable_seq_lens(prog, seq_lens);
auto input_shape = input->get_shape(); auto input_shape = input->get_shape();
...@@ -1204,7 +1204,7 @@ rewrite_rnn::get_seq_len(const program& prog, instruction_ref input, instruction ...@@ -1204,7 +1204,7 @@ rewrite_rnn::get_seq_len(const program& prog, instruction_ref input, instruction
return length; return length;
} }
instruction_ref rewrite_rnn::replace_last_hs_output(program& prog, instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
instruction_ref ins, instruction_ref ins,
instruction_ref seq_lens, instruction_ref seq_lens,
instruction_ref last_hs_output, instruction_ref last_hs_output,
...@@ -1243,7 +1243,7 @@ instruction_ref rewrite_rnn::replace_last_hs_output(program& prog, ...@@ -1243,7 +1243,7 @@ instruction_ref rewrite_rnn::replace_last_hs_output(program& prog,
return result_ins; return result_ins;
} }
void rewrite_rnn::replace_last_cell_output(program& prog, void rewrite_rnn::replace_last_cell_output(module& prog,
instruction_ref ins, instruction_ref ins,
instruction_ref seq_lens, instruction_ref seq_lens,
instruction_ref cell_outputs, instruction_ref cell_outputs,
...@@ -1281,7 +1281,7 @@ void rewrite_rnn::replace_last_cell_output(program& prog, ...@@ -1281,7 +1281,7 @@ void rewrite_rnn::replace_last_cell_output(program& prog,
} }
} }
instruction_ref rewrite_rnn::pad_hidden_states(program& prog, instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
instruction_ref seq, instruction_ref seq,
instruction_ref seq_lens, instruction_ref seq_lens,
instruction_ref hs) const instruction_ref hs) const
......
...@@ -103,7 +103,7 @@ struct stream_info ...@@ -103,7 +103,7 @@ struct stream_info
} }
}; };
std::size_t assign_streams(program& p, std::size_t n) std::size_t assign_streams(module& p, std::size_t n)
{ {
assert(n > 0); assert(n > 0);
partition critical; partition critical;
...@@ -182,7 +182,7 @@ struct stream_info ...@@ -182,7 +182,7 @@ struct stream_info
} }
}; };
void sort(program& p, std::size_t) const void sort(module& p, std::size_t)
{ {
std::set<weight_ins, compare_weight_ins> children; std::set<weight_ins, compare_weight_ins> children;
std::unordered_map<instruction_ref, std::size_t> visited; std::unordered_map<instruction_ref, std::size_t> visited;
...@@ -335,7 +335,7 @@ struct stream_info ...@@ -335,7 +335,7 @@ struct stream_info
} }
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>>
find_concurrent_instructions(program& p) const find_concurrent_instructions(module& p) const
{ {
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result; std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from; std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
...@@ -378,7 +378,7 @@ struct stream_info ...@@ -378,7 +378,7 @@ struct stream_info
} }
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>
get_conflicts(program& p) get_conflicts(module& p)
{ {
using conflict_table_type = using conflict_table_type =
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>; std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>;
...@@ -464,7 +464,7 @@ struct stream_info ...@@ -464,7 +464,7 @@ struct stream_info
} }
}; };
void schedule::apply(program& p) const void schedule::apply(module& p) const
{ {
if(not enable) if(not enable)
return; return;
......
...@@ -50,7 +50,7 @@ struct find_mul_conv ...@@ -50,7 +50,7 @@ struct find_mul_conv
match::name("broadcast").bind("a"))); match::name("broadcast").bind("a")));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto conv_ins = r.instructions["conv"]; auto conv_ins = r.instructions["conv"];
...@@ -86,7 +86,7 @@ struct find_mul_slice_conv ...@@ -86,7 +86,7 @@ struct find_mul_slice_conv
match::name("broadcast")(match::is_constant()).bind("a"))); match::name("broadcast")(match::is_constant()).bind("a")));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto slice_ins = r.instructions["slice"]; auto slice_ins = r.instructions["slice"];
...@@ -169,7 +169,7 @@ struct find_mul_add ...@@ -169,7 +169,7 @@ struct find_mul_add
match::is_constant().bind("a"))); match::is_constant().bind("a")));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto a_ins = r.instructions["a"]; auto a_ins = r.instructions["a"];
...@@ -191,7 +191,7 @@ struct find_add_lit_broadcast ...@@ -191,7 +191,7 @@ struct find_add_lit_broadcast
match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b"))); match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b")));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
...@@ -211,7 +211,7 @@ struct find_double_add_lit_broadcast ...@@ -211,7 +211,7 @@ struct find_double_add_lit_broadcast
match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y"))); match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
...@@ -249,7 +249,7 @@ struct find_inner_broadcast ...@@ -249,7 +249,7 @@ struct find_inner_broadcast
match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y"))); match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
...@@ -294,7 +294,7 @@ struct find_concat_op ...@@ -294,7 +294,7 @@ struct find_concat_op
return op.name() == "broadcast" or op.attributes().contains("pointwise"); return op.name() == "broadcast" or op.attributes().contains("pointwise");
} }
void apply(program& p, const match::matcher_result& r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto axis = any_cast<op::concat>(ins->get_operator()).axis; auto axis = any_cast<op::concat>(ins->get_operator()).axis;
...@@ -425,7 +425,7 @@ struct find_splits ...@@ -425,7 +425,7 @@ struct find_splits
return groups; return groups;
} }
void apply(program& p, const match::matcher_result& r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
...@@ -520,7 +520,7 @@ struct find_split_concat ...@@ -520,7 +520,7 @@ struct find_split_concat
match::name("slice")(match::all_of[match::outputs()](match::name("concat"))))); match::name("slice")(match::all_of[match::outputs()](match::name("concat")))));
} }
void apply(program& p, const match::matcher_result& r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
...@@ -618,7 +618,7 @@ struct find_add_convs ...@@ -618,7 +618,7 @@ struct find_add_convs
input.strides()[3] * n}}; input.strides()[3] * n}};
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto a_conv = r.instructions["a"]; auto a_conv = r.instructions["a"];
...@@ -689,7 +689,7 @@ struct find_conv_dot_horiz_fusion ...@@ -689,7 +689,7 @@ struct find_conv_dot_horiz_fusion
{ {
auto matcher() const { return horiz_conv_dot(); } auto matcher() const { return horiz_conv_dot(); }
void apply(program& p, const match::matcher_result& r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
...@@ -762,7 +762,7 @@ struct find_div_const ...@@ -762,7 +762,7 @@ struct find_div_const
return match::name("div")(match::arg(1)(match::is_constant().bind("c"))); return match::name("div")(match::arg(1)(match::is_constant().bind("c")));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto c_ins = r.instructions["c"]; auto c_ins = r.instructions["c"];
...@@ -782,7 +782,7 @@ struct find_sub_const ...@@ -782,7 +782,7 @@ struct find_sub_const
return match::name("sub")(match::arg(1)(match::is_constant().bind("c"))); return match::name("sub")(match::arg(1)(match::is_constant().bind("c")));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto c_ins = r.instructions["c"]; auto c_ins = r.instructions["c"];
...@@ -803,7 +803,7 @@ struct find_rsqrt ...@@ -803,7 +803,7 @@ struct find_rsqrt
match::name("sqrt")(match::used_once(), match::args(match::any().bind("x"))))); match::name("sqrt")(match::used_once(), match::args(match::any().bind("x")))));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
...@@ -828,7 +828,7 @@ struct find_split_reshape ...@@ -828,7 +828,7 @@ struct find_split_reshape
.bind("reshape"); .bind("reshape");
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto slc = r.instructions["slice"]; auto slc = r.instructions["slice"];
auto rsp = r.instructions["reshape"]; auto rsp = r.instructions["reshape"];
...@@ -904,7 +904,7 @@ struct find_split_transpose ...@@ -904,7 +904,7 @@ struct find_split_transpose
.bind("trans"); .bind("trans");
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto slc = r.instructions["slice"]; auto slc = r.instructions["slice"];
auto trans = r.instructions["trans"]; auto trans = r.instructions["trans"];
...@@ -949,7 +949,7 @@ struct find_split_transpose ...@@ -949,7 +949,7 @@ struct find_split_transpose
} }
}; };
void simplify_algebra::apply(program& p) const void simplify_algebra::apply(module& p) const
{ {
// Run simplifications multiple times // Run simplifications multiple times
for(int i = 0; i < 8; i++) for(int i = 0; i < 8; i++)
......
...@@ -66,7 +66,7 @@ struct find_reshaper ...@@ -66,7 +66,7 @@ struct find_reshaper
match::any_of[match::outputs()](match::name(reshaper_names()))); match::any_of[match::outputs()](match::name(reshaper_names())));
} }
void apply(program& p, const match::matcher_result& mr) const void apply(module& p, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
std::vector<instruction_ref> reshapes{ins}; std::vector<instruction_ref> reshapes{ins};
...@@ -113,7 +113,7 @@ struct find_nop_reshapes ...@@ -113,7 +113,7 @@ struct find_nop_reshapes
return match::name(reshapes)(match::same_shape(match::arg(0))); return match::name(reshapes)(match::same_shape(match::arg(0)));
} }
void apply(program& p, const match::matcher_result& mr) const void apply(module& p, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
p.replace_instruction(ins, ins->inputs().front()); p.replace_instruction(ins, ins->inputs().front());
...@@ -128,7 +128,7 @@ struct find_transpose ...@@ -128,7 +128,7 @@ struct find_transpose
match::skip_output(match::name("contiguous"))(match::name("transpose")))); match::skip_output(match::name("contiguous"))(match::name("transpose"))));
} }
void apply(program& p, const match::matcher_result& mr) const void apply(module& p, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto x = ins; auto x = ins;
...@@ -201,7 +201,7 @@ struct find_nested_slice ...@@ -201,7 +201,7 @@ struct find_nested_slice
return result; return result;
} }
void apply(program& p, const match::matcher_result& mr) const void apply(module& p, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto slice = ins->inputs().front(); auto slice = ins->inputs().front();
...@@ -230,7 +230,7 @@ struct find_concat_transpose ...@@ -230,7 +230,7 @@ struct find_concat_transpose
return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape())); return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape()));
} }
void apply(program& p, const match::matcher_result& mr) const void apply(module& p, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto trans_inputs = ins->inputs(); auto trans_inputs = ins->inputs();
...@@ -279,7 +279,7 @@ struct find_nested_concat ...@@ -279,7 +279,7 @@ struct find_nested_concat
return op.axis; return op.axis;
} }
void apply(program& p, const match::matcher_result& mr) const void apply(module& p, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto axis = get_axis(ins); auto axis = get_axis(ins);
...@@ -298,7 +298,7 @@ struct find_nested_concat ...@@ -298,7 +298,7 @@ struct find_nested_concat
} }
}; };
void simplify_reshapes::apply(program& p) const void simplify_reshapes::apply(module& p) const
{ {
for(int i = 0; i < 2; i++) for(int i = 0; i < 2; i++)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment