Unverified commit 2466dd6f, authored by Shucai Xiao and committed by GitHub

Refactor program to module (#684)



* code backup

* clang format

* change corresponding tool files

* clang format
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent de10423f
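The change is mechanical but wide: `module` is introduced as an alias of `program`, pass and parser interfaces are declared against the alias, and instruction building in the ONNX parser is routed through the program's main module. A minimal sketch of the pass-side pattern, assuming only the alias shown in the headers below (`example_pass` is a hypothetical name, not part of this commit):

#include <string>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct program;
using module = program; // transitional alias added by this commit

// Hypothetical pass: the interface now names `module`, but because of the
// alias every existing call site that passes a `program` compiles unchanged.
struct example_pass
{
    std::string name() const { return "example_pass"; }
    void apply(module& m) const; // previously: void apply(program& p) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx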
@@ -10,6 +10,7 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
 struct program;
+using module = program;
 
 /**
  * Schedule instructions for concurrent execution
@@ -19,7 +20,7 @@ struct schedule
     schedule_model model{};
     bool enable = true;
     std::string name() const { return "schedule"; }
-    void apply(program& p) const;
+    void apply(module& p) const;
 };
 
 } // namespace MIGRAPHX_INLINE_NS
...
@@ -16,6 +16,7 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
 struct program;
+using module = program;
 struct operation;
 
 #ifdef DOXYGEN
@@ -26,11 +27,11 @@ struct schedule_model
     /// Get the number of concurrent instruction allowed
     std::size_t concurrency() const;
     /// Schedule a concurrent instruction
-    void sched(program& p, instruction_ref ins, std::size_t n) const;
+    void sched(module& p, instruction_ref ins, std::size_t n) const;
     // Insert necessary waits before an instruction
-    void wait(program& p, instruction_ref ins, std::size_t wait_id) const;
+    void wait(module& p, instruction_ref ins, std::size_t wait_id) const;
     // Insert necessary records after an instruction
-    void record(program& p, instruction_ref ins, std::size_t wait_id) const;
+    void record(module& p, instruction_ref ins, std::size_t wait_id) const;
     /// Compute weights for an operation
     std::size_t weight(const operation& op) const;
 };
@@ -43,9 +44,9 @@ struct schedule_model
  * struct schedule_model
  * {
  *     std::size_t concurrency() const;
- *     void sched(program& p,instruction_ref ins,std::size_t n) const;
+ *     void sched(module& p,instruction_ref ins,std::size_t n) const;
- *     void wait(program& p,instruction_ref ins,std::size_t wait_id) const;
+ *     void wait(module& p,instruction_ref ins,std::size_t wait_id) const;
- *     void record(program& p,instruction_ref ins,std::size_t wait_id) const;
+ *     void record(module& p,instruction_ref ins,std::size_t wait_id) const;
  *     std::size_t weight(const operation& op) const;
  * };
  *
@@ -120,19 +121,19 @@ struct schedule_model
         return (*this).private_detail_te_get_handle().concurrency();
     }
 
-    void sched(program& p, instruction_ref ins, std::size_t n) const
+    void sched(module& p, instruction_ref ins, std::size_t n) const
     {
         assert((*this).private_detail_te_handle_mem_var);
         (*this).private_detail_te_get_handle().sched(p, ins, n);
     }
 
-    void wait(program& p, instruction_ref ins, std::size_t wait_id) const
+    void wait(module& p, instruction_ref ins, std::size_t wait_id) const
    {
         assert((*this).private_detail_te_handle_mem_var);
         (*this).private_detail_te_get_handle().wait(p, ins, wait_id);
     }
 
-    void record(program& p, instruction_ref ins, std::size_t wait_id) const
+    void record(module& p, instruction_ref ins, std::size_t wait_id) const
     {
         assert((*this).private_detail_te_handle_mem_var);
         (*this).private_detail_te_get_handle().record(p, ins, wait_id);
@@ -158,11 +159,11 @@ struct schedule_model
         virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
         virtual const std::type_info& type() const = 0;
         virtual std::size_t concurrency() const = 0;
-        virtual void sched(program& p, instruction_ref ins, std::size_t n) const = 0;
+        virtual void sched(module& p, instruction_ref ins, std::size_t n) const = 0;
-        virtual void wait(program& p, instruction_ref ins, std::size_t wait_id) const = 0;
+        virtual void wait(module& p, instruction_ref ins, std::size_t wait_id) const = 0;
-        virtual void record(program& p, instruction_ref ins, std::size_t wait_id) const = 0;
+        virtual void record(module& p, instruction_ref ins, std::size_t wait_id) const = 0;
         virtual std::size_t weight(const operation& op) const = 0;
     };
 
     template <typename PrivateDetailTypeErasedT>
@@ -195,19 +196,19 @@ struct schedule_model
         std::size_t concurrency() const override { return private_detail_te_value.concurrency(); }
 
-        void sched(program& p, instruction_ref ins, std::size_t n) const override
+        void sched(module& p, instruction_ref ins, std::size_t n) const override
         {
             private_detail_te_value.sched(p, ins, n);
         }
 
-        void wait(program& p, instruction_ref ins, std::size_t wait_id) const override
+        void wait(module& p, instruction_ref ins, std::size_t wait_id) const override
         {
             private_detail_te_value.wait(p, ins, wait_id);
         }
 
-        void record(program& p, instruction_ref ins, std::size_t wait_id) const override
+        void record(module& p, instruction_ref ins, std::size_t wait_id) const override
         {
             private_detail_te_value.record(p, ins, wait_id);
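For context, the Doxygen block above documents the full customization surface. Any type with these five members can be stored in the type-erased `schedule_model`; a minimal conforming model might look like the sketch below, written with the header's declarations (`module`, `instruction_ref`, `operation`) in scope. `sequential_model` is a hypothetical illustration, not code from this commit.

// A do-nothing model: one stream, no waits or events, uniform weights.
struct sequential_model
{
    std::size_t concurrency() const { return 1; }                // single stream
    void sched(module&, instruction_ref, std::size_t) const {}   // nothing to schedule
    void wait(module&, instruction_ref, std::size_t) const {}    // no cross-stream waits
    void record(module&, instruction_ref, std::size_t) const {}  // no events to record
    std::size_t weight(const operation&) const { return 1; }     // uniform cost
};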
...
@@ -8,6 +8,7 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
 struct program;
+using module = program;
 
 /**
  * Simplify many algebraic instructions to more efficient versions.
@@ -15,7 +16,7 @@ struct program;
 struct simplify_algebra
 {
     std::string name() const { return "simplify_algebra"; }
-    void apply(program& p) const;
+    void apply(module& p) const;
 };
 
 } // namespace MIGRAPHX_INLINE_NS
...
@@ -9,6 +9,7 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
 struct program;
+using module = program;
 
 /**
  * Eliminate redundant reshapes.
@@ -16,7 +17,7 @@ struct program;
 struct simplify_reshapes
 {
     std::string name() const { return "simplify_reshapes"; }
-    void apply(program& p) const;
+    void apply(module& p) const;
 };
 
 } // namespace MIGRAPHX_INLINE_NS
...
@@ -61,7 +61,7 @@ int main(int argc, char const* argv[])
     {
         // GPU target
         prog.compile(migraphx::gpu::target{});
-        migraphx::program::parameter_map m;
+        migraphx::parameter_map m;
         auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}};
         for(auto&& x : prog.get_parameter_shapes())
         {
...
@@ -124,7 +124,7 @@ int main(int argc, char const* argv[])
     auto s = migraphx::shape{migraphx::shape::float_type, {1, 1, 28, 28}};
     std::cout << s << std::endl;
     auto* ptr = input.data();
-    migraphx::program::parameter_map m;
+    migraphx::parameter_map m;
     m["output"] =
         migraphx::gpu::to_gpu(migraphx::generate_argument(prog.get_parameter_shape("output")));
     for(int i = 0; i < 20; i++)
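Both tools pick up the same rename: `parameter_map` is now a namespace-scope alias rather than a type nested inside `program`. A sketch of an updated caller, assuming the `generate_argument` helper already used in the hunks above (`make_inputs` is a hypothetical name):

#include <migraphx/generate.hpp>
#include <migraphx/program.hpp>

// Fill every named parameter of a program with generated data.
migraphx::parameter_map make_inputs(const migraphx::program& prog)
{
    migraphx::parameter_map m; // was: migraphx::program::parameter_map
    for(auto&& x : prog.get_parameter_shapes())
        m[x.first] = migraphx::generate_argument(x.second);
    return m;
}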
...
@@ -68,6 +68,7 @@ struct onnx_parser
     node_map nodes;
     std::unordered_map<std::string, instruction_ref> instructions;
     program prog = program();
+    module* mm = prog.get_main_module();
     bool is_pytorch = false;
     std::size_t default_dim_value = 1;
     std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
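This hunk establishes the pattern the rest of the file follows: the parser still owns `prog`, but every `add_instruction`/`add_literal` call is routed through the main-module pointer. A minimal sketch of the same construction flow (`build` is a hypothetical helper, not part of the diff; the `make_op.hpp` header is assumed):

#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>

migraphx::program build()
{
    migraphx::program prog;
    auto* mm = prog.get_main_module(); // owned by prog; never delete it
    auto x = mm->add_literal(1.0f);    // was: prog.add_literal(1.0f)
    auto y = mm->add_literal(2.0f);
    mm->add_instruction(migraphx::make_op("add"), x, y); // was: prog.add_instruction(...)
    return prog;
}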
@@ -266,11 +267,11 @@ struct onnx_parser
         if(broadcasted != 0)
         {
             uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>();
-            auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
-                                          args[1]);
+            auto l = mm->add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
+                                         args[1]);
-            return prog.add_instruction(make_op(op_name), args[0], l);
+            return mm->add_instruction(make_op(op_name), args[0], l);
         }
-        return prog.add_instruction(make_op(op_name), args);
+        return mm->add_instruction(make_op(op_name), args);
     }
     else
     {
@@ -318,14 +319,14 @@ struct onnx_parser
         return out_lens;
     }
 
-    instruction_ref make_contiguous(instruction_ref ins)
+    instruction_ref make_contiguous(instruction_ref ins) const
     {
         if(ins->get_shape().standard())
         {
             return ins;
         }
 
-        return prog.add_instruction(make_op("contiguous"), ins);
+        return mm->add_instruction(make_op("contiguous"), ins);
     }
 
     instruction_ref
@@ -340,17 +341,17 @@ struct onnx_parser
         auto l0 = arg0;
         if(arg0->get_shape().lens() != out_lens)
-            l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);
+            l0 = mm->add_instruction(op::multibroadcast{out_lens}, arg0);
 
         auto l1 = arg1;
         if(arg1->get_shape().lens() != out_lens)
-            l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);
+            l1 = mm->add_instruction(op::multibroadcast{out_lens}, arg1);
 
-        return prog.add_instruction(make_op(name), l0, l1);
+        return mm->add_instruction(make_op(name), l0, l1);
     }
     else
     {
-        return prog.add_instruction(make_op(name), {arg0, arg1});
+        return mm->add_instruction(make_op(name), {arg0, arg1});
     }
 }
@@ -368,7 +369,7 @@ struct onnx_parser
             return this->make_contiguous(arg);
         });
     }
-    return prog.add_instruction(op, args);
+    return mm->add_instruction(op, args);
 });
 }
@@ -391,14 +392,15 @@ struct onnx_parser
         return output_vector;
     }
 
-    instruction_ref
-    add_bias(const std::vector<instruction_ref>& args, instruction_ref curr_ins, uint64_t axis)
+    instruction_ref add_bias(const std::vector<instruction_ref>& args,
+                             instruction_ref curr_ins,
+                             uint64_t axis) const
     {
         if(args.size() == 3)
         {
             auto bias_bcast =
-                prog.add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]);
+                mm->add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]);
-            return prog.add_instruction(make_op("add"), curr_ins, bias_bcast);
+            return mm->add_instruction(make_op("add"), curr_ins, bias_bcast);
         }
         return curr_ins;
     }
@@ -422,7 +424,7 @@ struct onnx_parser
                        const std::vector<int64_t>& padding,
                        value& v,
                        int count_include_pad = 0,
-                       float pad_val = 0)
+                       float pad_val = 0) const
     {
         size_t pad_ndims = padding.size() / 2;
         auto left_pad_it = padding.begin();
@@ -435,7 +437,7 @@ struct onnx_parser
             asym_pads.insert(asym_pads.begin() + 2, left_pad_it, right_pad_it);
             // add right pads
             asym_pads.insert(asym_pads.begin() + pad_ndims + 4, right_pad_it, padding.end());
-            ins = prog.add_instruction(op::pad{asym_pads, pad_val}, ins);
+            ins = mm->add_instruction(op::pad{asym_pads, pad_val}, ins);
         }
         else
         {
@@ -444,7 +446,7 @@ struct onnx_parser
     }
 
     instruction_ref
-    parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args)
+    parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args) const
     {
         auto input_lens = args[0]->get_shape().lens();
         instruction_ref min_arg;
@@ -469,44 +471,44 @@ struct onnx_parser
             float min_val = parse_value(info.attributes.at("min")).at<float>();
             float max_val = parse_value(info.attributes.at("max")).at<float>();
-            min_arg = prog.add_literal(min_val);
+            min_arg = mm->add_literal(min_val);
-            max_arg = prog.add_literal(max_val);
+            max_arg = mm->add_literal(max_val);
             min_used = true;
             max_used = true;
         }
 
         if(min_used)
         {
-            min_arg = prog.add_instruction(op::multibroadcast{input_lens}, min_arg);
+            min_arg = mm->add_instruction(op::multibroadcast{input_lens}, min_arg);
         }
 
         if(max_used)
         {
-            max_arg = prog.add_instruction(op::multibroadcast{input_lens}, max_arg);
+            max_arg = mm->add_instruction(op::multibroadcast{input_lens}, max_arg);
         }
 
         if(min_used and max_used)
         {
-            return prog.add_instruction(make_op("clip"), args[0], min_arg, max_arg);
+            return mm->add_instruction(make_op("clip"), args[0], min_arg, max_arg);
         }
         else if(max_used)
         {
-            return prog.add_instruction(make_op("min"), args[0], max_arg);
+            return mm->add_instruction(make_op("min"), args[0], max_arg);
         }
         else if(min_used)
         {
-            return prog.add_instruction(make_op("max"), args[0], min_arg);
+            return mm->add_instruction(make_op("max"), args[0], min_arg);
         }
         else
         {
-            return prog.add_instruction(make_op("identity"), args[0]);
+            return mm->add_instruction(make_op("identity"), args[0]);
         }
     }
 
     instruction_ref parse_arg_op(const std::string&,
                                  const std::string& op_name,
                                  node_info info,
-                                 std::vector<instruction_ref> args)
+                                 std::vector<instruction_ref> args) const
     {
         int64_t axis = 0;
         if(contains(info.attributes, "axis"))
@@ -522,12 +524,12 @@ struct onnx_parser
         if(keep_dims == 0)
         {
-            auto ins = prog.add_instruction(make_op(op_name, {{"axis", axis}}), std::move(args));
+            auto ins = mm->add_instruction(make_op(op_name, {{"axis", axis}}), std::move(args));
-            return prog.add_instruction(op::squeeze{{axis}}, ins);
+            return mm->add_instruction(op::squeeze{{axis}}, ins);
        }
        else
        {
-            return prog.add_instruction(make_op(op_name, {{"axis", axis}}), std::move(args));
+            return mm->add_instruction(make_op(op_name, {{"axis", axis}}), std::move(args));
         }
     }
@@ -591,7 +593,7 @@ struct onnx_parser
             {
                 *starts_it = idx;
                 *ends_it = *starts_it + 1;
-                slices.push_back(prog.add_instruction(op::slice{axes, starts, ends}, input));
+                slices.push_back(mm->add_instruction(op::slice{axes, starts, ends}, input));
             }
             // when padding on the left side, the outermost pad should be at the beginning
             std::reverse(slices.begin(), slices.end());
@@ -600,9 +602,9 @@ struct onnx_parser
             {
                 *starts_it = *dims_it - idx - 1;
                 *ends_it = *starts_it + 1;
-                slices.push_back(prog.add_instruction(op::slice{axes, starts, ends}, input));
+                slices.push_back(mm->add_instruction(op::slice{axes, starts, ends}, input));
             }
-            input = prog.add_instruction(op::concat{axis}, slices);
+            input = mm->add_instruction(op::concat{axis}, slices);
         }
         return input;
     }
@@ -747,7 +749,7 @@ struct onnx_parser
         recalc_conv_attributes(values, kdims);
 
         op.from_value(values);
-        auto l1 = prog.add_instruction(op, l0, args[1]);
+        auto l1 = mm->add_instruction(op, l0, args[1]);
         return add_bias(args, l1, 1);
     }
@@ -821,7 +823,7 @@ struct onnx_parser
         recalc_conv_attributes(values, kdims);
 
         op.from_value(values);
-        auto l1 = prog.add_instruction(op, l0, args[1]);
+        auto l1 = mm->add_instruction(op, l0, args[1]);
 
         std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
         std::vector<int64_t> curr_shape(dims.begin() + 2, dims.end());
         if(asym_padding)
@@ -839,7 +841,7 @@ struct onnx_parser
                            std::back_inserter(ends),
                            [](auto curr_dim, auto pad_dim) { return curr_dim - pad_dim; });
 
-            l1 = prog.add_instruction(op::slice{axes, starts, ends}, l1);
+            l1 = mm->add_instruction(op::slice{axes, starts, ends}, l1);
         }
 
         if(contains(info.attributes, "output_padding"))
@@ -850,7 +852,7 @@ struct onnx_parser
             check_attr_sizes(kdims,
                              output_padding.size() - non_kdims,
                              "PARSE_CONV_TRANSPOSE: inconsistent output padding");
-            l1 = prog.add_instruction(op::pad{output_padding}, l1);
+            l1 = mm->add_instruction(op::pad{output_padding}, l1);
         }
 
         if(contains(info.attributes, "output_shape"))
@@ -869,7 +871,7 @@ struct onnx_parser
                            curr_shape.begin(),
                            std::back_inserter(target_padding),
                            [](auto out_dim, auto curr_dim) { return out_dim - curr_dim; });
-            l1 = prog.add_instruction(op::pad{target_padding}, l1);
+            l1 = mm->add_instruction(op::pad{target_padding}, l1);
         }
     }
@@ -1042,12 +1044,12 @@ struct onnx_parser
             }
         }
 
         op.from_value(values);
-        auto l1 = prog.add_instruction(op, l0);
+        auto l1 = mm->add_instruction(op, l0);
         if(!slice_start.empty())
         {
             std::vector<int64_t> axes(kdims);
             std::iota(axes.begin(), axes.end(), 2);
-            l1 = prog.add_instruction(op::slice{axes, slice_start, slice_end}, l1);
+            l1 = mm->add_instruction(op::slice{axes, slice_start, slice_end}, l1);
         }
 
         return l1;
@@ -1069,7 +1071,7 @@ struct onnx_parser
             s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
         }
 
-        return prog.add_instruction(op, make_contiguous(args[0]));
+        return mm->add_instruction(op, make_contiguous(args[0]));
     }
 
     static const auto& get_nearest_op(const std::string& mode)
@@ -1248,9 +1250,9 @@ struct onnx_parser
         // reshape input to one-dimension
         std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
         shape ind_s{shape::int32_type, out_lens};
-        auto rsp = prog.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
+        auto rsp = mm->add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
-        auto ins_ind = prog.add_literal(literal(ind_s, ind));
+        auto ins_ind = mm->add_literal(literal(ind_s, ind));
-        return prog.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
+        return mm->add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
     }
 
     instruction_ref
@@ -1281,7 +1283,7 @@ struct onnx_parser
         int64_t data_elem_num = static_cast<int64_t>(data_s.elements());
         // reshape the input data as one dimension and used as input data
         // to the gather operator
-        arg_data = prog.add_instruction(op::reshape{{data_elem_num}}, arg_data);
+        arg_data = mm->add_instruction(op::reshape{{data_elem_num}}, arg_data);
 
         std::size_t elem_num = ind_s.elements();
         std::vector<int> ind_index(elem_num);
@@ -1299,16 +1301,16 @@ struct onnx_parser
         });
 
         auto l_shape_idx =
-            prog.add_literal(literal(ind_s, data_indices.begin(), data_indices.end()));
+            mm->add_literal(literal(ind_s, data_indices.begin(), data_indices.end()));
-        auto l_dim_idx = prog.add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
+        auto l_dim_idx = mm->add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
-        auto l_stride = prog.add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
+        auto l_stride = mm->add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
-        l_stride = prog.add_instruction(op::multibroadcast{ind_s.lens()}, l_stride);
+        l_stride = mm->add_instruction(op::multibroadcast{ind_s.lens()}, l_stride);
 
-        auto dim_diff = prog.add_instruction(make_op("sub"), arg_ind, l_dim_idx);
+        auto dim_diff = mm->add_instruction(make_op("sub"), arg_ind, l_dim_idx);
-        auto delta = prog.add_instruction(make_op("mul"), dim_diff, l_stride);
+        auto delta = mm->add_instruction(make_op("mul"), dim_diff, l_stride);
-        auto ind = prog.add_instruction(make_op("add"), l_shape_idx, delta);
+        auto ind = mm->add_instruction(make_op("add"), l_shape_idx, delta);
 
         op::gather op{0};
-        return prog.add_instruction(op, arg_data, ind);
+        return mm->add_instruction(op, arg_data, ind);
     }
 
     instruction_ref
@@ -1373,17 +1375,17 @@ struct onnx_parser
             op.axes = axes;
         }
 
-        return prog.add_instruction(op, args[0]);
+        return mm->add_instruction(op, args[0]);
     }
 
     instruction_ref
-    parse_constant(const std::string&, node_info info, const std::vector<instruction_ref>&)
+    parse_constant(const std::string&, node_info info, const std::vector<instruction_ref>&) const
     {
         literal v = parse_value(info.attributes.at("value"));
         // return empty literal
         if(v.get_shape().elements() == 0)
         {
-            return prog.add_literal(literal{});
+            return mm->add_literal(literal{});
         }
 
         auto dim_size = info.attributes.at("value").t().dims_size();
@@ -1391,14 +1393,14 @@ struct onnx_parser
         if(dim_size == 0)
         {
             migraphx::shape scalar_shape{v.get_shape().type()};
-            return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
+            return mm->add_literal(migraphx::literal{scalar_shape, v.data()});
         }
 
-        return prog.add_literal(v);
+        return mm->add_literal(v);
     }
 
     instruction_ref
-    parse_gemm(const std::string&, node_info info, std::vector<instruction_ref> args)
+    parse_gemm(const std::string&, node_info info, std::vector<instruction_ref> args) const
     {
         float alpha = 1.0f;
         float beta = 1.0f;
@@ -1426,8 +1428,8 @@ struct onnx_parser
         // swap the last two elements
         std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
 
-        auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
+        auto l1 = (transa) ? mm->add_instruction(op::transpose{perm}, args[0]) : args[0];
-        auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
+        auto l2 = (transb) ? mm->add_instruction(op::transpose{perm}, args[1]) : args[1];
 
         if(args.size() == 3)
         {
             if(beta != 0.f && args[2]->get_shape().elements() > 0)
@@ -1438,14 +1440,14 @@ struct onnx_parser
                 auto l3_lens = l3->get_shape().lens();
                 if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
                 {
-                    l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
+                    l3 = mm->add_instruction(op::multibroadcast{out_lens}, args[2]);
                 }
-                return prog.add_instruction(
-                    make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
+                return mm->add_instruction(
+                    make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
             }
         }
 
-        return prog.add_instruction(make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2);
+        return mm->add_instruction(make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2);
     }
 
     instruction_ref parse_matmul(const std::string&,
@@ -1464,7 +1466,7 @@ struct onnx_parser
         {
             is_a_prepended = true;
             l0_lens.insert(l0_lens.begin(), 1);
-            l0 = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
+            l0 = mm->add_instruction(op::unsqueeze{{0}}, args[0]);
         }
 
         bool is_b_appended = false;
@@ -1472,7 +1474,7 @@ struct onnx_parser
         {
             is_b_appended = true;
             l1_lens.push_back(1);
-            l1 = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
+            l1 = mm->add_instruction(op::unsqueeze{{1}}, args[1]);
         }
 
         instruction_ref bl0 = l0;
@@ -1490,32 +1492,31 @@ struct onnx_parser
             l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
             if(l0_lens != l0_broadcasted_lens)
             {
-                bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
+                bl0 = mm->add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
             }
             if(l1_lens != l1_broadcasted_lens)
             {
-                bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
+                bl1 = mm->add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
             }
         }
 
-        auto dot_res =
-            prog.add_instruction(make_op(op_name, {{"alpha", 1}, {"beta", 0}}), bl0, bl1);
+        auto dot_res = mm->add_instruction(make_op(op_name, {{"alpha", 1}, {"beta", 0}}), bl0, bl1);
         int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
         if(is_a_prepended)
         {
-            dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
+            dot_res = mm->add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
             --num_axis;
         }
         if(is_b_appended)
         {
-            dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
+            dot_res = mm->add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
         }
 
         return dot_res;
     }
 
     instruction_ref
-    parse_batchnorm(const std::string&, node_info info, std::vector<instruction_ref> args)
+    parse_batchnorm(const std::string&, node_info info, std::vector<instruction_ref> args) const
     {
         float epsilon = 1e-5f;
         float momentum = 0.9f;
@@ -1535,11 +1536,11 @@ struct onnx_parser
                           : op::batch_norm_inference::per_activation;
         }
 
         op::batch_norm_inference op{epsilon, momentum, bn_mode};
-        return prog.add_instruction(op, std::move(args));
+        return mm->add_instruction(op, std::move(args));
     }
 
     instruction_ref
-    parse_instancenorm(const std::string&, node_info info, std::vector<instruction_ref> args)
+    parse_instancenorm(const std::string&, node_info info, std::vector<instruction_ref> args) const
     {
         // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
         // mean = reduce_mean({D1, D2, ... Dk}, x)
@@ -1561,26 +1562,26 @@ struct onnx_parser
         std::vector<int64_t> axes(kdims);
         std::iota(axes.begin(), axes.end(), 2);
 
-        auto mean = prog.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
+        auto mean = mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
-        auto mean_bcast = prog.add_instruction(op::multibroadcast{dims}, mean);
+        auto mean_bcast = mm->add_instruction(op::multibroadcast{dims}, mean);
-        auto l0 = prog.add_instruction(make_op("sqdiff"), x, mean_bcast);
+        auto l0 = mm->add_instruction(make_op("sqdiff"), x, mean_bcast);
-        auto variance = prog.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
+        auto variance = mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
-        auto l1 = prog.add_instruction(make_op("sub"), x, mean_bcast);
+        auto l1 = mm->add_instruction(make_op("sub"), x, mean_bcast);
-        auto epsilon_literal = prog.add_literal(epsilon);
+        auto epsilon_literal = mm->add_literal(epsilon);
-        auto epsilon_bcast = prog.add_instruction(op::multibroadcast{dims}, epsilon_literal);
+        auto epsilon_bcast = mm->add_instruction(op::multibroadcast{dims}, epsilon_literal);
-        auto variance_bcast = prog.add_instruction(op::multibroadcast{dims}, variance);
+        auto variance_bcast = mm->add_instruction(op::multibroadcast{dims}, variance);
-        auto l2 = prog.add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
+        auto l2 = mm->add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
-        auto l3 = prog.add_instruction(make_op("rsqrt"), l2);
+        auto l3 = mm->add_instruction(make_op("rsqrt"), l2);
-        auto l4 = prog.add_instruction(make_op("mul"), l1, l3);
+        auto l4 = mm->add_instruction(make_op("mul"), l1, l3);
-        auto scale_bcast = prog.add_instruction(op::broadcast{1, dims}, scale);
+        auto scale_bcast = mm->add_instruction(op::broadcast{1, dims}, scale);
         ;
-        auto bias_bcast = prog.add_instruction(op::broadcast{1, dims}, bias);
+        auto bias_bcast = mm->add_instruction(op::broadcast{1, dims}, bias);
-        auto l5 = prog.add_instruction(make_op("mul"), l4, scale_bcast);
+        auto l5 = mm->add_instruction(make_op("mul"), l4, scale_bcast);
-        return prog.add_instruction(make_op("add"), l5, bias_bcast);
+        return mm->add_instruction(make_op("add"), l5, bias_bcast);
     }
 
     instruction_ref
-    parse_leaky_relu(const std::string&, node_info info, std::vector<instruction_ref> args)
+    parse_leaky_relu(const std::string&, node_info info, std::vector<instruction_ref> args) const
     {
         float alpha = 0.01; // default alpha val for leaky relu
         if(contains(info.attributes, "alpha"))
@@ -1588,10 +1589,11 @@ struct onnx_parser
             alpha = parse_value(info.attributes.at("alpha")).at<float>();
         }
         auto op = make_op("leaky_relu", {{"alpha", alpha}});
-        return prog.add_instruction(op, args.front());
+        return mm->add_instruction(op, args.front());
     }
 
-    instruction_ref parse_elu(const std::string&, node_info info, std::vector<instruction_ref> args)
+    instruction_ref
+    parse_elu(const std::string&, node_info info, std::vector<instruction_ref> args) const
     {
         float alpha = 1.0; // default alpha val for elu
         if(contains(info.attributes, "alpha"))
@@ -1599,10 +1601,11 @@ struct onnx_parser
             alpha = parse_value(info.attributes.at("alpha")).at<float>();
         }
         auto op = make_op("elu", {{"alpha", alpha}});
-        return prog.add_instruction(op, args.front());
+        return mm->add_instruction(op, args.front());
     }
 
-    instruction_ref parse_lrn(const std::string&, node_info info, std::vector<instruction_ref> args)
+    instruction_ref
+    parse_lrn(const std::string&, node_info info, std::vector<instruction_ref> args) const
     {
         float alpha = 0.0001;
         float beta = 0.75;
@@ -1617,11 +1620,11 @@ struct onnx_parser
         if(contains(info.attributes, "size"))
             size = parse_value(info.attributes.at("size")).at<int>();
         op::lrn op{alpha, beta, bias, size};
-        return prog.add_instruction(op, args.front());
+        return mm->add_instruction(op, args.front());
     }
 
     instruction_ref
-    parse_imagescaler(const std::string&, node_info info, std::vector<instruction_ref> args)
+    parse_imagescaler(const std::string&, node_info info, std::vector<instruction_ref> args) const
     {
         float scale = 1.0;
         std::vector<float> bias{};
@@ -1639,18 +1642,17 @@ struct onnx_parser
         auto const& input_lens = input_shape.lens();
         auto input_type = input_shape.type();
 
-        auto scale_val = prog.add_literal(literal{shape{input_type}, {scale}});
+        auto scale_val = mm->add_literal(literal{shape{input_type}, {scale}});
-        auto bias_vals = prog.add_literal(literal{shape{input_type, {bias.size()}}, bias});
+        auto bias_vals = mm->add_literal(literal{shape{input_type, {bias.size()}}, bias});
 
-        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
+        auto scale_tensor = mm->add_instruction(migraphx::op::scalar{input_lens}, scale_val);
-        auto img_scaled =
-            prog.add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor);
+        auto img_scaled = mm->add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor);
-        auto bias_bcast = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
+        auto bias_bcast = mm->add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
-        return prog.add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
+        return mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
     }
 
     instruction_ref
-    parse_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
+    parse_transpose(const std::string&, node_info info, std::vector<instruction_ref> args) const
     {
         std::vector<int64_t> perm{};
         if(contains(info.attributes, "perm"))
@@ -1658,7 +1660,7 @@ struct onnx_parser
             auto&& perm_vals = info.attributes["perm"].ints();
             perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
         }
-        return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
+        return mm->add_instruction(migraphx::op::transpose{perm}, args.front());
     }
 
     instruction_ref parse_pad(const std::string&, node_info info, std::vector<instruction_ref> args)
@@ -1683,7 +1685,7 @@ struct onnx_parser
         // check if padding is actually being done (at least one value is nonzero)
         if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
         {
-            return prog.add_instruction(make_op("identity"), args.front());
+            return mm->add_instruction(make_op("identity"), args.front());
         }
 
         if(contains(info.attributes, "mode"))
@@ -1719,11 +1721,11 @@ struct onnx_parser
             value = parse_value(info.attributes.at("value")).at<float>();
         }
 
-        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
+        return mm->add_instruction(migraphx::op::pad{pads, value}, args.front());
     }
 
     instruction_ref
-    parse_selu(const std::string&, const node_info& info, std::vector<instruction_ref> args)
+    parse_selu(const std::string&, const node_info& info, std::vector<instruction_ref> args) const
     {
         auto type = args[0]->get_shape().type();
         auto lens = args[0]->get_shape().lens();
@@ -1739,35 +1741,35 @@ struct onnx_parser
             gamma = info.attributes.at("gamma").f();
         }
 
-        auto l_alpha = prog.add_literal({{type, {1}}, {alpha}});
+        auto l_alpha = mm->add_literal({{type, {1}}, {alpha}});
-        auto l_gamma = prog.add_literal({{type, {1}}, {gamma / 2.0f}});
+        auto l_gamma = mm->add_literal({{type, {1}}, {gamma / 2.0f}});
 
         if(lens != std::vector<std::size_t>{1})
         {
             l_alpha =
-                prog.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_alpha);
+                mm->add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_alpha);
             l_gamma =
-                prog.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_gamma);
+                mm->add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_gamma);
         }
 
-        auto sign_x = prog.add_instruction(make_op("sign"), args[0]);
+        auto sign_x = mm->add_instruction(make_op("sign"), args[0]);
-        auto exp_x = prog.add_instruction(make_op("exp"), args[0]);
+        auto exp_x = mm->add_instruction(make_op("exp"), args[0]);
 
-        auto alpha_ex = prog.add_instruction(make_op("mul"), l_alpha, exp_x);
+        auto alpha_ex = mm->add_instruction(make_op("mul"), l_alpha, exp_x);
-        auto aex_alpha = prog.add_instruction(make_op("sub"), alpha_ex, l_alpha);
+        auto aex_alpha = mm->add_instruction(make_op("sub"), alpha_ex, l_alpha);
 
-        auto ins1 = prog.add_instruction(make_op("add"), aex_alpha, args[0]);
+        auto ins1 = mm->add_instruction(make_op("add"), aex_alpha, args[0]);
-        auto ins2 = prog.add_instruction(make_op("sub"), aex_alpha, args[0]);
+        auto ins2 = mm->add_instruction(make_op("sub"), aex_alpha, args[0]);
 
-        auto sign2 = prog.add_instruction(make_op("mul"), sign_x, ins2);
+        auto sign2 = mm->add_instruction(make_op("mul"), sign_x, ins2);
-        auto ins_sub = prog.add_instruction(make_op("sub"), ins1, sign2);
+        auto ins_sub = mm->add_instruction(make_op("sub"), ins1, sign2);
 
-        return prog.add_instruction(make_op("mul"), ins_sub, l_gamma);
+        return mm->add_instruction(make_op("mul"), ins_sub, l_gamma);
     }
 
     // Use a literal instruction to replace the shape since, output of
     // shape operator are literals in migraphx
     instruction_ref
-    parse_shape(const std::string&, const node_info&, std::vector<instruction_ref> args)
+    parse_shape(const std::string&, const node_info&, std::vector<instruction_ref> args) const
     {
         if(args.size() != 1)
             MIGRAPHX_THROW("Shape: operator should have 1 operand");
@@ -1777,7 +1779,7 @@ struct onnx_parser
         std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
             return int64_t(i);
         });
-        return prog.add_literal(migraphx::literal{s, vec_shape});
+        return mm->add_literal(migraphx::literal{s, vec_shape});
     }
 
     // Use a literal instruction to replace the constantFill operator. In RNN, input shape
@@ -1831,7 +1833,7 @@ struct onnx_parser
             in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
             migraphx::shape s(type, dims);
             std::vector<float> values(s.elements(), value);
-            return prog.add_literal(migraphx::literal(s, values));
+            return mm->add_literal(migraphx::literal(s, values));
         }
         else if(input_as_shape == 0)
         {
@@ -1845,7 +1847,7 @@ struct onnx_parser
             ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
             migraphx::shape s{type, dims};
             std::vector<float> values(s.elements(), value);
-            return prog.add_literal(migraphx::literal(s, values));
+            return mm->add_literal(migraphx::literal(s, values));
         }
         else
         {
@@ -1903,7 +1905,7 @@ struct onnx_parser
                 l_out = literal(s, out_vec);
             });
 
-            return prog.add_literal(l_out);
+            return mm->add_literal(l_out);
         }
     }
@@ -1916,7 +1918,7 @@ struct onnx_parser
         std::vector<std::size_t> dims;
         arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
         auto out_lens = compute_broadcasted_lens(in_lens, dims);
-        return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
+        return mm->add_instruction(op::multibroadcast{out_lens}, args[0]);
     }
 
     std::vector<instruction_ref>
@@ -1999,16 +2001,16 @@ struct onnx_parser
         // undefined operator to have 6 arguments
         if(args.size() < 6)
         {
-            auto ins = prog.add_instruction(op::undefined{});
+            auto ins = mm->add_instruction(op::undefined{});
             args.insert(args.end(), (6 - args.size()), ins);
         }
 
         // first output for the concatenation of hidden states
-        auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
-                                                  std::move(args));
+        auto hidden_states =
+            mm->add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip}, std::move(args));
 
         // second output for the last hidden state
-        auto last_output = prog.add_instruction(op::rnn_last_hs_output{}, hidden_states);
+        auto last_output = mm->add_instruction(op::rnn_last_hs_output{}, hidden_states);
 
         return {hidden_states, last_output};
     }
@@ -2120,17 +2122,17 @@ struct onnx_parser
         // append undefined opeator to make 6 arguments
         if(args.size() < 6)
         {
-            auto ins = prog.add_instruction(op::undefined{});
+            auto ins = mm->add_instruction(op::undefined{});
             args.insert(args.end(), 6 - args.size(), ins);
         }
 
         // first output for concatenation of hidden states
-        auto hidden_states = prog.add_instruction(
-            op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
-            std::move(args));
+        auto hidden_states = mm->add_instruction(
+            op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
+            std::move(args));
 
         // second output for last gru output
-        auto last_output = prog.add_instruction(op::rnn_last_hs_output{}, hidden_states);
+        auto last_output = mm->add_instruction(op::rnn_last_hs_output{}, hidden_states);
 
         return {hidden_states, last_output};
     }
@@ -2302,18 +2304,18 @@ struct onnx_parser
         // append undefined opeator to make 6 arguments
         if(args.size() < 8)
         {
-            auto ins = prog.add_instruction(op::undefined{});
+            auto ins = mm->add_instruction(op::undefined{});
             args.insert(args.end(), 8 - args.size(), ins);
         }
 
         // first output for concatenation of hidden states
-        auto hidden_states = prog.add_instruction(
-            op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
+        auto hidden_states = mm->add_instruction(
+            op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
 
-        auto last_output = prog.add_instruction(op::rnn_last_hs_output{}, hidden_states);
+        auto last_output = mm->add_instruction(op::rnn_last_hs_output{}, hidden_states);
 
         // third output for last cell output
-        auto last_cell_output = prog.add_instruction(op::rnn_last_cell_output{}, hidden_states);
+        auto last_cell_output = mm->add_instruction(op::rnn_last_cell_output{}, hidden_states);
 
         return {hidden_states, last_output, last_cell_output};
     }
@@ -2321,7 +2323,7 @@ struct onnx_parser
     instruction_ref parse_reduce_oper(const std::string&,
                                       const std::string& op_name,
                                       node_info info,
-                                      std::vector<instruction_ref> args)
+                                      std::vector<instruction_ref> args) const
     {
         std::size_t n_dim = args.front()->get_shape().lens().size();
@@ -2343,54 +2345,57 @@ struct onnx_parser
         if(keep_dims == 1)
         {
-            return prog.add_instruction(make_op(op_name, {{"axes", axes}}), std::move(args));
+            return mm->add_instruction(make_op(op_name, {{"axes", axes}}), std::move(args));
         }
         else
         {
-            auto ins = prog.add_instruction(make_op(op_name, {{"axes", axes}}), std::move(args));
+            auto ins = mm->add_instruction(make_op(op_name, {{"axes", axes}}), std::move(args));
-            return prog.add_instruction(op::squeeze{axes}, ins);
+            return mm->add_instruction(op::squeeze{axes}, ins);
         }
     }
 
     instruction_ref
-    parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args)
+    parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args) const
     {
-        auto abs_ins = prog.add_instruction(make_op("abs"), args[0]);
+        auto abs_ins = mm->add_instruction(make_op("abs"), args[0]);
         return parse_reduce_oper({}, "reduce_sum", std::move(info), {abs_ins});
     }
 
     instruction_ref
-    parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args)
+    parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args) const
     {
-        auto square_ins = prog.add_instruction(make_op("mul"), args[0], args[0]);
+        auto square_ins = mm->add_instruction(make_op("mul"), args[0], args[0]);
         auto sum_ins = parse_reduce_oper({}, "reduce_sum", std::move(info), {square_ins});
-        return prog.add_instruction(make_op("sqrt"), sum_ins);
+        return mm->add_instruction(make_op("sqrt"), sum_ins);
     }
 
-    instruction_ref
-    parse_reduce_log_sum(const std::string&, node_info info, std::vector<instruction_ref> args)
+    instruction_ref parse_reduce_log_sum(const std::string&,
+                                         node_info info,
+                                         std::vector<instruction_ref> args) const
     {
         auto sum_ins = parse_reduce_oper({}, "reduce_sum", std::move(info), std::move(args));
-        return prog.add_instruction(make_op("log"), sum_ins);
+        return mm->add_instruction(make_op("log"), sum_ins);
     }
 
-    instruction_ref
-    parse_reduce_log_sum_exp(const std::string&, node_info info, std::vector<instruction_ref> args)
+    instruction_ref parse_reduce_log_sum_exp(const std::string&,
+                                             node_info info,
+                                             std::vector<instruction_ref> args) const
     {
-        auto exp_ins = prog.add_instruction(make_op("exp"), args[0]);
+        auto exp_ins = mm->add_instruction(make_op("exp"), args[0]);
         auto sum_ins = parse_reduce_oper({}, "reduce_sum", std::move(info), {exp_ins});
-        return prog.add_instruction(make_op("log"), sum_ins);
+        return mm->add_instruction(make_op("log"), sum_ins);
     }
 
-    instruction_ref
-    parse_reduce_sum_square(const std::string&, node_info info, std::vector<instruction_ref> args)
+    instruction_ref parse_reduce_sum_square(const std::string&,
+                                            node_info info,
+                                            std::vector<instruction_ref> args) const
     {
-        auto square_ins = prog.add_instruction(make_op("mul"), args[0], args[0]);
+        auto square_ins = mm->add_instruction(make_op("mul"), args[0], args[0]);
         return parse_reduce_oper({}, "reduce_sum", std::move(info), {square_ins});
     }
 
     instruction_ref
-    parse_cast(const std::string&, node_info info, std::vector<instruction_ref> args)
+    parse_cast(const std::string&, node_info info, std::vector<instruction_ref> args) const
     {
         if(!contains(info.attributes, "to"))
         {
@@ -2399,7 +2404,7 @@ struct onnx_parser
         int to_type = parse_value(info.attributes.at("to")).at<int>();
         shape::type_t type = get_type(to_type);
-        return prog.add_instruction(make_op("convert", {{"target_type", type}}), std::move(args));
+        return mm->add_instruction(make_op("convert", {{"target_type", type}}), std::move(args));
     }
 
     std::vector<instruction_ref>
@@ -2448,7 +2453,7 @@ struct onnx_parser
         for(auto sl : vec_splits)
         {
             ret_ins.push_back(
-                prog.add_instruction(op::slice{{axis}, {start}, {start + sl}}, args[0]));
+                mm->add_instruction(op::slice{{axis}, {start}, {start + sl}}, args[0]));
             start += sl;
         }
@@ -2476,8 +2481,8 @@ struct onnx_parser
         auto type = args[2]->get_shape().type();
         shape s{type, {depth, depth}};
-        auto l_val = prog.add_literal({s, depth_input});
+        auto l_val = mm->add_literal({s, depth_input});
-        auto gather_out = prog.add_instruction(op::gather{0}, {l_val, args[0]});
+        auto gather_out = mm->add_instruction(op::gather{0}, {l_val, args[0]});
 
         // Finally, we need a transpose to move the inner most dim to the axis dim
         int n_rank = gather_out->get_shape().lens().size();
@@ -2489,16 +2494,16 @@ struct onnx_parser
         std::vector<int64_t> perm(n_rank - 1);
         std::iota(perm.begin(), perm.end(), 0);
         perm.insert(perm.begin() + tuned_axis, n_rank - 1);
-        auto tr_out = prog.add_instruction(op::transpose{perm}, gather_out);
+        auto tr_out = mm->add_instruction(op::transpose{perm}, gather_out);
         auto lens = tr_out->get_shape().lens();
 
-        auto off_val = prog.add_instruction(op::slice{{0}, {0}, {1}}, args[2]);
+        auto off_val = mm->add_instruction(op::slice{{0}, {0}, {1}}, args[2]);
-        auto on_val = prog.add_instruction(op::slice{{0}, {1}, {2}}, args[2]);
+        auto on_val = mm->add_instruction(op::slice{{0}, {1}, {2}}, args[2]);
-        auto diff = prog.add_instruction(make_op("sub"), on_val, off_val);
+        auto diff = mm->add_instruction(make_op("sub"), on_val, off_val);
-        auto unsq_off_val = prog.add_instruction(op::multibroadcast{lens}, off_val);
+        auto unsq_off_val = mm->add_instruction(op::multibroadcast{lens}, off_val);
-        auto unsq_diff_val = prog.add_instruction(op::multibroadcast{lens}, diff);
+        auto unsq_diff_val = mm->add_instruction(op::multibroadcast{lens}, diff);
-        auto l_mul = prog.add_instruction(make_op("mul"), tr_out, unsq_diff_val);
+        auto l_mul = mm->add_instruction(make_op("mul"), tr_out, unsq_diff_val);
-        return prog.add_instruction(make_op("add"), l_mul, unsq_off_val);
+        return mm->add_instruction(make_op("add"), l_mul, unsq_off_val);
     }
 
     instruction_ref
@@ -2515,7 +2520,7 @@ struct onnx_parser
             auto l1 = l0;
             for(int j = 1; j < repeats[i]; j++)
            {
-                l0 = prog.add_instruction(op::concat{i}, l0, l1);
+                l0 = mm->add_instruction(op::concat{i}, l0, l1);
             }
         }
         return l0;
...@@ -2557,7 +2562,7 @@ struct onnx_parser ...@@ -2557,7 +2562,7 @@ struct onnx_parser
return result; return result;
}); });
l0 = prog.add_literal({shape{args[0]->get_shape().type(), {num_elements}}, range_vals}); l0 = mm->add_literal({shape{args[0]->get_shape().type(), {num_elements}}, range_vals});
}); });
return l0; return l0;
} }
...@@ -2569,7 +2574,8 @@ struct onnx_parser ...@@ -2569,7 +2574,8 @@ struct onnx_parser
max = 2 max = 2
}; };
instruction_ref parse_embedding_bag(const node_info& info, std::vector<instruction_ref> args) instruction_ref parse_embedding_bag(const node_info& info,
std::vector<instruction_ref> args) const
{ {
if(args[2]->get_shape().elements() != 1) if(args[2]->get_shape().elements() != 1)
MIGRAPHX_THROW("PARSE_EMBEDDING_BAG: MIGraphX only supports offsets of size 1"); MIGRAPHX_THROW("PARSE_EMBEDDING_BAG: MIGraphX only supports offsets of size 1");
...@@ -2579,24 +2585,24 @@ struct onnx_parser ...@@ -2579,24 +2585,24 @@ struct onnx_parser
reduce_mode = static_cast<reduce_mode_t>(info.attributes.at("mode").i()); reduce_mode = static_cast<reduce_mode_t>(info.attributes.at("mode").i());
} }
auto l0 = prog.add_instruction(op::gather{}, args[0], args[1]); auto l0 = mm->add_instruction(op::gather{}, args[0], args[1]);
switch(reduce_mode) switch(reduce_mode)
{ {
case reduce_mode_t::sum: case reduce_mode_t::sum:
l0 = prog.add_instruction(make_op("reduce_sum", {{"axes", {0}}}), l0); l0 = mm->add_instruction(make_op("reduce_sum", {{"axes", {0}}}), l0);
break; break;
case reduce_mode_t::mean: case reduce_mode_t::mean:
l0 = prog.add_instruction(make_op("reduce_mean", {{"axes", {0}}}), l0); l0 = mm->add_instruction(make_op("reduce_mean", {{"axes", {0}}}), l0);
break; break;
case reduce_mode_t::max: case reduce_mode_t::max:
l0 = prog.add_instruction(make_op("reduce_max", {{"axes", {0}}}), l0); l0 = mm->add_instruction(make_op("reduce_max", {{"axes", {0}}}), l0);
break; break;
} }
return l0; return l0;
} }
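Functionally the hunk above is a row gather followed by a reduction over axis 0; a standalone sketch of the mean mode, with the table and indices assumed for illustration:

#include <cstdio>
#include <vector>

int main()
{
    std::vector<float> weight = {1, 2, 3, 4, 5, 6}; // 3x2 table (args[0]), row-major
    std::vector<int> indices  = {0, 2};             // args[1]; one bag since offsets == {0}
    float sum0 = 0, sum1 = 0;
    for(int i : indices) // stands in for op::gather{} over rows
    {
        sum0 += weight[i * 2 + 0];
        sum1 += weight[i * 2 + 1];
    }
    // reduce_mean over axis 0
    std::printf("%g %g\n", sum0 / indices.size(), sum1 / indices.size()); // 3 4
    return 0;
}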
instruction_ref instruction_ref
parse_aten(const std::string&, const node_info& info, std::vector<instruction_ref> args) parse_aten(const std::string&, const node_info& info, std::vector<instruction_ref> args) const
{ {
if(contains(info.attributes, "operator")) if(contains(info.attributes, "operator"))
{ {
...@@ -2610,13 +2616,13 @@ struct onnx_parser ...@@ -2610,13 +2616,13 @@ struct onnx_parser
} }
std::vector<instruction_ref> std::vector<instruction_ref>
parse_dropout(const std::string&, const node_info&, std::vector<instruction_ref> args) parse_dropout(const std::string&, const node_info&, std::vector<instruction_ref> args) const
{ {
auto out = prog.add_instruction(make_op("identity"), args[0]); auto out = mm->add_instruction(make_op("identity"), args[0]);
auto s = args[0]->get_shape(); auto s = args[0]->get_shape();
std::vector<int8_t> vec(s.elements(), 1); std::vector<int8_t> vec(s.elements(), 1);
shape mask_s{shape::bool_type, s.lens()}; shape mask_s{shape::bool_type, s.lens()};
auto mask = prog.add_literal(literal(mask_s, vec)); auto mask = mm->add_literal(literal(mask_s, vec));
return {out, mask}; return {out, mask};
} }
...@@ -2661,7 +2667,7 @@ struct onnx_parser ...@@ -2661,7 +2667,7 @@ struct onnx_parser
} }
} }
return prog.add_literal(literal(out_s, out_data)); return mm->add_literal(literal(out_s, out_data));
} }
instruction_ref parse_compare_op(const std::string&, instruction_ref parse_compare_op(const std::string&,
...@@ -2672,7 +2678,7 @@ struct onnx_parser ...@@ -2672,7 +2678,7 @@ struct onnx_parser
auto l = add_broadcastable_binary_op(args[0], args[1], op_name); auto l = add_broadcastable_binary_op(args[0], args[1], op_name);
if(l->get_shape().type() != shape::bool_type) if(l->get_shape().type() != shape::bool_type)
{ {
l = prog.add_instruction(make_op("convert", {{"target_type", shape::bool_type}}), l); l = mm->add_instruction(make_op("convert", {{"target_type", shape::bool_type}}), l);
} }
return l; return l;
} }
...@@ -2737,9 +2743,9 @@ struct onnx_parser ...@@ -2737,9 +2743,9 @@ struct onnx_parser
// reshape input to one dimension // reshape input to one dimension
std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())}; std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
shape ind_s{shape::int32_type, out_lens}; shape ind_s{shape::int32_type, out_lens};
auto rsp = prog.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]); auto rsp = mm->add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
auto ins_ind = prog.add_literal(literal(ind_s, ind)); auto ins_ind = mm->add_literal(literal(ind_s, ind));
return prog.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind); return mm->add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
} }
instruction_ref instruction_ref
...@@ -2748,7 +2754,7 @@ struct onnx_parser ...@@ -2748,7 +2754,7 @@ struct onnx_parser
auto type = args[1]->get_shape().type(); auto type = args[1]->get_shape().type();
// the operation "if cond == 1, select x; else select y" // the operation "if cond == 1, select x; else select y"
// is equivalent to cond * (x - y) + y // is equivalent to cond * (x - y) + y
auto cond = prog.add_instruction(make_op("convert", {{"target_type", type}}), args[0]); auto cond = mm->add_instruction(make_op("convert", {{"target_type", type}}), args[0]);
auto diff = add_broadcastable_binary_op(args[1], args[2], "sub"); auto diff = add_broadcastable_binary_op(args[1], args[2], "sub");
auto cd = add_broadcastable_binary_op(diff, cond, "mul"); auto cd = add_broadcastable_binary_op(diff, cond, "mul");
return add_broadcastable_binary_op(cd, args[2], "add"); return add_broadcastable_binary_op(cd, args[2], "add");
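The identity used here reproduces elementwise select once the boolean condition is converted to the data type: cond * (x - y) + y yields x when cond is 1 and y when cond is 0. A quick standalone check with illustrative values:

#include <cassert>

int main()
{
    float x = 5.0f, y = -3.0f;
    for(float cond : {0.0f, 1.0f}) // args[0] after the convert above
        assert(cond * (x - y) + y == (cond == 1.0f ? x : y));
    return 0;
}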
...@@ -2795,7 +2801,7 @@ struct onnx_parser ...@@ -2795,7 +2801,7 @@ struct onnx_parser
{ {
for(auto&& f : graph.initializer()) for(auto&& f : graph.initializer())
{ {
instructions[f.name()] = prog.add_literal(parse_tensor(f)); instructions[f.name()] = mm->add_literal(parse_tensor(f));
} }
for(auto&& input : graph.input()) for(auto&& input : graph.input())
...@@ -2811,7 +2817,7 @@ struct onnx_parser ...@@ -2811,7 +2817,7 @@ struct onnx_parser
} }
shape s = parse_type(input.type(), dims); shape s = parse_type(input.type(), dims);
instructions[name] = prog.add_parameter(name, s); instructions[name] = mm->add_parameter(name, s);
} }
} }
...@@ -2837,7 +2843,7 @@ struct onnx_parser ...@@ -2837,7 +2843,7 @@ struct onnx_parser
if(ops.count(node.op_type()) == 0) if(ops.count(node.op_type()) == 0)
{ {
if(skip_unknown_operators) if(skip_unknown_operators)
result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args)); result.push_back(mm->add_instruction(op::unknown{node.op_type()}, args));
else else
MIGRAPHX_THROW("Unknown operator: " + node.op_type()); MIGRAPHX_THROW("Unknown operator: " + node.op_type());
} }
...@@ -2875,14 +2881,14 @@ struct onnx_parser ...@@ -2875,14 +2881,14 @@ struct onnx_parser
[&](const auto& name) { return instructions[name]; }); [&](const auto& name) { return instructions[name]; });
// add the return instruction // add the return instruction
prog.add_return(output_ins); mm->add_return(output_ins);
} }
void parse_undefined(const std::string& name) void parse_undefined(const std::string& name)
{ {
if(!contains(instructions, name)) if(!contains(instructions, name))
{ {
auto ins = prog.add_instruction(op::undefined{}); auto ins = mm->add_instruction(op::undefined{});
instructions[name] = ins; instructions[name] = ins;
} }
} }
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring::apply(program& p) const void memory_coloring::apply(module& p) const
{ {
if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{})) if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
{ {
......
...@@ -67,7 +67,7 @@ using interval_ptr = live_interval*; ...@@ -67,7 +67,7 @@ using interval_ptr = live_interval*;
struct memory_coloring_impl struct memory_coloring_impl
{ {
memory_coloring_impl(program* p, std::string alloc_op, bool p_verify) memory_coloring_impl(module* p, std::string alloc_op, bool p_verify)
: p_program(p), allocation_op(std::move(alloc_op)), enable_verify(p_verify) : p_program(p), allocation_op(std::move(alloc_op)), enable_verify(p_verify)
{ {
instr2_live.clear(); instr2_live.clear();
...@@ -145,7 +145,7 @@ struct memory_coloring_impl ...@@ -145,7 +145,7 @@ struct memory_coloring_impl
return (i1->offset > i2->offset); return (i1->offset > i2->offset);
} }
}; };
program* p_program; module* p_program;
std::unordered_map<const instruction*, interval_ptr> instr2_live; std::unordered_map<const instruction*, interval_ptr> instr2_live;
// universe of live intervals. // universe of live intervals.
std::vector<live_interval> live_intervals; std::vector<live_interval> live_intervals;
......
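The enabled(...) guard that gates memory_coloring above is the stock pattern for making a pass skippable from the environment; a hedged sketch of wiring the same switch into a new pass (the pass and variable names are assumptions):

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MY_PASS) // hypothetical variable

void my_pass::apply(module& p) const
{
    if(enabled(MIGRAPHX_DISABLE_MY_PASS{}))
        return; // opted out via the environment, mirroring memory_coloring
    // ... rewrite instructions in p here ...
}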
...@@ -15,20 +15,20 @@ ...@@ -15,20 +15,20 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace) void run_passes(module& modl, const std::vector<pass>& passes, tracer trace)
{ {
for(const auto& p : passes) for(const auto& p : passes)
{ {
trace("Pass: ", p.name()); trace("Pass: ", p.name());
p.apply(prog); p.apply(modl);
trace(prog); trace(modl);
#ifndef NDEBUG #ifndef NDEBUG
trace("Validate ..."); trace("Validate ...");
auto invalid = prog.validate(); auto invalid = modl.validate();
if(invalid != prog.end()) if(invalid != modl.end())
{ {
auto index = std::distance(prog.begin(), invalid); auto index = std::distance(modl.begin(), invalid);
MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " + MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " +
std::to_string(index) + ": " + invalid->name()); std::to_string(index) + ": " + invalid->name());
} }
......
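run_passes only needs a name()/apply(module&) pair from each element, so custom passes slot in alongside the built-ins; a hedged sketch (the struct name is an assumption):

struct noop_pass
{
    std::string name() const { return "noop_pass"; }
    void apply(module&) const {} // a real pass would rewrite the module here
};

// assumed usage: run_passes(modl, {noop_pass{}}, tracer{});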
...@@ -20,7 +20,7 @@ bool skip_propogate(instruction_ref ins) ...@@ -20,7 +20,7 @@ bool skip_propogate(instruction_ref ins)
return false; return false;
} }
void propagate_constant::apply(program& p) const void propagate_constant::apply(module& p) const
{ {
for(auto i : iterator_for(p)) for(auto i : iterator_for(p))
{ {
......
...@@ -256,8 +256,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -256,8 +256,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::target>(m, "target"); py::class_<migraphx::target>(m, "target");
py::class_<migraphx::module_wrap>(m, "module") py::class_<migraphx::module>(m, "module").def("print", [](const migraphx::module& mm) {
.def("print", [](const migraphx::module_wrap& mm) { std::cout << *mm.prog << std::endl; }); std::cout << mm << std::endl;
});
py::class_<migraphx::program>(m, "program") py::class_<migraphx::program>(m, "program")
.def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); }) .def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); })
...@@ -277,12 +278,12 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -277,12 +278,12 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("fast_math") = true) py::arg("fast_math") = true)
.def("get_main_module", .def("get_main_module",
[](migraphx::program& p) { [](migraphx::program& p) {
auto mm = p.get_main_module(); auto* mm = p.get_main_module();
return migraphx::module_wrap{mm}; return migraphx::module{*mm};
}) })
.def("run", .def("run",
[](migraphx::program& p, py::dict params) { [](migraphx::program& p, py::dict params) {
migraphx::program::parameter_map pm; migraphx::parameter_map pm;
for(auto x : params) for(auto x : params)
{ {
std::string key = x.first.cast<std::string>(); std::string key = x.first.cast<std::string>();
...@@ -389,7 +390,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -389,7 +390,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
&migraphx::quantize_int8, &migraphx::quantize_int8,
py::arg("prog"), py::arg("prog"),
py::arg("t"), py::arg("t"),
py::arg("calibration") = std::vector<migraphx::program::parameter_map>{}, py::arg("calibration") = std::vector<migraphx::parameter_map>{},
py::arg("ins_names") = std::vector<std::string>{"dot", "convolution"}); py::arg("ins_names") = std::vector<std::string>{"dot", "convolution"});
#ifdef HAVE_GPU #ifdef HAVE_GPU
......
...@@ -27,7 +27,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -27,7 +27,7 @@ inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
instruction_ref insert_quant_ins(program& prog, instruction_ref insert_quant_ins(module& modl,
instruction_ref& ins, instruction_ref& ins,
shape::type_t type, shape::type_t type,
std::unordered_map<instruction_ref, instruction_ref>& map_ins, std::unordered_map<instruction_ref, instruction_ref>& map_ins,
...@@ -59,11 +59,11 @@ instruction_ref insert_quant_ins(program& prog, ...@@ -59,11 +59,11 @@ instruction_ref insert_quant_ins(program& prog,
if(scaled_ins->get_shape().type() != shape::float_type) if(scaled_ins->get_shape().type() != shape::float_type)
{ {
float_ins = float_ins =
prog.insert_instruction(insert_loc, op::convert{shape::float_type}, scaled_ins); modl.insert_instruction(insert_loc, op::convert{shape::float_type}, scaled_ins);
} }
std::vector<float> vec_scale(scaled_ins->get_shape().elements(), scale); std::vector<float> vec_scale(scaled_ins->get_shape().elements(), scale);
auto l_scale = prog.add_literal(literal(float_ins->get_shape(), vec_scale)); auto l_scale = modl.add_literal(literal(float_ins->get_shape(), vec_scale));
scaled_ins = prog.insert_instruction(insert_loc, op::mul{}, l_scale, float_ins); scaled_ins = modl.insert_instruction(insert_loc, op::mul{}, l_scale, float_ins);
} }
auto shifted_ins = scaled_ins; auto shifted_ins = scaled_ins;
...@@ -72,27 +72,27 @@ instruction_ref insert_quant_ins(program& prog, ...@@ -72,27 +72,27 @@ instruction_ref insert_quant_ins(program& prog,
auto float_ins = shifted_ins; auto float_ins = shifted_ins;
if(shifted_ins->get_shape().type() != shape::float_type) if(shifted_ins->get_shape().type() != shape::float_type)
{ {
float_ins = prog.insert_instruction( float_ins = modl.insert_instruction(
insert_loc, op::convert{shape::float_type}, shifted_ins); insert_loc, op::convert{shape::float_type}, shifted_ins);
} }
std::vector<float> vec_shift(shifted_ins->get_shape().elements(), shift); std::vector<float> vec_shift(shifted_ins->get_shape().elements(), shift);
auto l_shift = prog.add_literal(literal(float_ins->get_shape(), vec_shift)); auto l_shift = modl.add_literal(literal(float_ins->get_shape(), vec_shift));
shifted_ins = prog.insert_instruction(insert_loc, op::add{}, l_shift, float_ins); shifted_ins = modl.insert_instruction(insert_loc, op::add{}, l_shift, float_ins);
} }
auto rounded_ins = prog.insert_instruction(insert_loc, op::round{}, shifted_ins); auto rounded_ins = modl.insert_instruction(insert_loc, op::round{}, shifted_ins);
auto rounded_lens = rounded_ins->get_shape().lens(); auto rounded_lens = rounded_ins->get_shape().lens();
auto max_clip = prog.add_literal(127.0f); auto max_clip = modl.add_literal(127.0f);
auto min_clip = prog.add_literal(-128.0f); auto min_clip = modl.add_literal(-128.0f);
max_clip = prog.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, max_clip); max_clip = modl.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, max_clip);
min_clip = prog.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, min_clip); min_clip = modl.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, min_clip);
auto clipped_ins = auto clipped_ins =
prog.insert_instruction(insert_loc, op::clip{}, rounded_ins, min_clip, max_clip); modl.insert_instruction(insert_loc, op::clip{}, rounded_ins, min_clip, max_clip);
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, clipped_ins); quant_ins = modl.insert_instruction(insert_loc, op::convert{type}, clipped_ins);
} }
else else
{ {
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, ins); quant_ins = modl.insert_instruction(insert_loc, op::convert{type}, ins);
} }
map_ins[ins] = quant_ins; map_ins[ins] = quant_ins;
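Unrolled, the int8 branch above computes convert(clip(round(x * scale + shift), -128, 127)); a standalone numeric sketch, with the scale and shift values assumed for illustration:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    float x = 3.7f, scale = 20.0f, shift = 0.0f;
    float r = std::round(x * scale + shift);     // op::round
    r = std::min(std::max(r, -128.0f), 127.0f);  // op::clip with the two literals above
    auto q = static_cast<std::int8_t>(r);        // op::convert{type}
    std::printf("%d\n", q);                      // 74
    return 0;
}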
...@@ -107,8 +107,9 @@ instruction_ref insert_quant_ins(program& prog, ...@@ -107,8 +107,9 @@ instruction_ref insert_quant_ins(program& prog,
// truncate of the input to get the fp16. // truncate of the input to get the fp16.
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names) void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
{ {
auto* mm = prog.get_main_module();
std::unordered_map<instruction_ref, instruction_ref> map_fp16; std::unordered_map<instruction_ref, instruction_ref> map_fp16;
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(*mm))
{ {
if(ins->name() == "@return") if(ins->name() == "@return")
break; break;
...@@ -139,7 +140,7 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names) ...@@ -139,7 +140,7 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
} }
else else
{ {
input_fp16 = insert_quant_ins(prog, input, shape::half_type, map_fp16); input_fp16 = insert_quant_ins(*mm, input, shape::half_type, map_fp16);
} }
converted_inputs.push_back(input_fp16); converted_inputs.push_back(input_fp16);
} }
...@@ -162,18 +163,18 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names) ...@@ -162,18 +163,18 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
// check the dead code case to avoid assert // check the dead code case to avoid assert
bool output_empty = ins->outputs().empty(); bool output_empty = ins->outputs().empty();
auto ins_orig_type = auto ins_orig_type =
prog.insert_instruction(std::next(ins), op::convert{orig_type}, ins); mm->insert_instruction(std::next(ins), op::convert{orig_type}, ins);
if(!output_empty) if(!output_empty)
{ {
prog.replace_instruction(ins, ins_orig_type); mm->replace_instruction(ins, ins_orig_type);
} }
} }
prog.replace_instruction(ins, op, converted_inputs); mm->replace_instruction(ins, op, converted_inputs);
} }
} }
static void ins_quantize_int8(program& prog, static void ins_quantize_int8(module& modl,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref>& converted_inputs, std::vector<instruction_ref>& converted_inputs,
const std::vector<std::pair<float, float>>& ins_quant_params) const std::vector<std::pair<float, float>>& ins_quant_params)
...@@ -195,14 +196,14 @@ static void ins_quantize_int8(program& prog, ...@@ -195,14 +196,14 @@ static void ins_quantize_int8(program& prog,
int32_t quant_beta = static_cast<int32_t>(std::round(new_beta)); int32_t quant_beta = static_cast<int32_t>(std::round(new_beta));
if(shape::int32_type == orig_type) if(shape::int32_type == orig_type)
{ {
prog.replace_instruction( modl.replace_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs); ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
} }
else else
{ {
auto quant_dot = prog.insert_instruction( auto quant_dot = modl.insert_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs); ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
prog.replace_instruction(ins, op::convert{orig_type}, quant_dot); modl.replace_instruction(ins, op::convert{orig_type}, quant_dot);
} }
} }
// either alpha or beta cannot be quantized because of too big // either alpha or beta cannot be quantized because of too big
...@@ -213,51 +214,51 @@ static void ins_quantize_int8(program& prog, ...@@ -213,51 +214,51 @@ static void ins_quantize_int8(program& prog,
{ {
converted_inputs.pop_back(); converted_inputs.pop_back();
} }
auto q_dot = prog.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs); auto q_dot = modl.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs);
auto f_dot = prog.insert_instruction(ins, op::convert{shape::float_type}, q_dot); auto f_dot = modl.insert_instruction(ins, op::convert{shape::float_type}, q_dot);
auto c_shape = q_dot->get_shape(); auto c_shape = q_dot->get_shape();
std::vector<float> vec_alpha(c_shape.elements(), new_alpha); std::vector<float> vec_alpha(c_shape.elements(), new_alpha);
auto l_alpha = auto l_alpha =
prog.add_literal(literal({shape::float_type, c_shape.lens()}, vec_alpha)); modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_alpha));
if(inputs.size() == 3 and dot_op.beta != 0.0f) if(inputs.size() == 3 and dot_op.beta != 0.0f)
{ {
auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, f_dot); auto alpha_ab = modl.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
std::vector<float> vec_beta(c_shape.elements(), dot_op.beta); std::vector<float> vec_beta(c_shape.elements(), dot_op.beta);
auto l_beta = auto l_beta =
prog.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta)); modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta));
instruction_ref beta_c{}; instruction_ref beta_c{};
if(orig_type != shape::float_type) if(orig_type != shape::float_type)
{ {
auto fp32_c = auto fp32_c =
prog.insert_instruction(ins, op::convert{shape::float_type}, inputs.back()); modl.insert_instruction(ins, op::convert{shape::float_type}, inputs.back());
beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, fp32_c); beta_c = modl.insert_instruction(ins, op::mul{}, l_beta, fp32_c);
} }
else else
{ {
beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back()); beta_c = modl.insert_instruction(ins, op::mul{}, l_beta, inputs.back());
} }
if(orig_type == shape::float_type) if(orig_type == shape::float_type)
{ {
prog.replace_instruction(ins, op::add{}, alpha_ab, beta_c); modl.replace_instruction(ins, op::add{}, alpha_ab, beta_c);
} }
else else
{ {
auto f_res = prog.insert_instruction(ins, op::add{}, alpha_ab, beta_c); auto f_res = modl.insert_instruction(ins, op::add{}, alpha_ab, beta_c);
prog.replace_instruction(ins, op::convert{orig_type}, f_res); modl.replace_instruction(ins, op::convert{orig_type}, f_res);
} }
} }
else else
{ {
if(orig_type == shape::float_type) if(orig_type == shape::float_type)
{ {
prog.replace_instruction(ins, op::mul{}, l_alpha, f_dot); modl.replace_instruction(ins, op::mul{}, l_alpha, f_dot);
} }
else else
{ {
auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, f_dot); auto alpha_ab = modl.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
prog.replace_instruction(ins, op::convert{orig_type}, alpha_ab); modl.replace_instruction(ins, op::convert{orig_type}, alpha_ab);
} }
} }
} }
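When alpha or beta cannot be folded into quant_dot, the code keeps quant_dot{1, 0} exact in int32 and applies new_alpha = alpha / (scale_a * scale_b) in float afterwards; a standalone check of that rescale with assumed scales:

#include <cmath>
#include <cstdio>

int main()
{
    float a = 1.5f, b = 2.0f;                            // one-element "dot"
    float scale_a = 20.0f, scale_b = 10.0f, alpha = 1.0f;
    int qa = static_cast<int>(std::round(a * scale_a));  // 30
    int qb = static_cast<int>(std::round(b * scale_b));  // 20
    int q_dot = qa * qb;                                 // quant_dot{1, 0} -> 600
    float new_alpha = alpha / (scale_a * scale_b);       // 1 / 200
    std::printf("%g vs %g\n", new_alpha * q_dot, alpha * a * b); // 3 vs 3
    return 0;
}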
...@@ -274,7 +275,7 @@ static void ins_quantize_int8(program& prog, ...@@ -274,7 +275,7 @@ static void ins_quantize_int8(program& prog,
auto group = conv_op.group; auto group = conv_op.group;
auto adjust_factor = 1.0f / (ins_quant_params[0].first * ins_quant_params[1].first); auto adjust_factor = 1.0f / (ins_quant_params[0].first * ins_quant_params[1].first);
auto quant_conv = prog.insert_instruction( auto quant_conv = modl.insert_instruction(
ins, ins,
op::quant_convolution{padding, stride, dilation, padding_mode, group}, op::quant_convolution{padding, stride, dilation, padding_mode, group},
converted_inputs); converted_inputs);
...@@ -282,25 +283,25 @@ static void ins_quantize_int8(program& prog, ...@@ -282,25 +283,25 @@ static void ins_quantize_int8(program& prog,
std::vector<float> vec_factor(quant_conv->get_shape().elements(), adjust_factor); std::vector<float> vec_factor(quant_conv->get_shape().elements(), adjust_factor);
if(quant_conv->get_shape().type() == orig_type and adjust_factor >= threshold) if(quant_conv->get_shape().type() == orig_type and adjust_factor >= threshold)
{ {
auto l_factor = prog.add_literal( auto l_factor = modl.add_literal(
literal(quant_conv->get_shape(), vec_factor.begin(), vec_factor.end())); literal(quant_conv->get_shape(), vec_factor.begin(), vec_factor.end()));
prog.replace_instruction(ins, op::mul{}, quant_conv, l_factor); modl.replace_instruction(ins, op::mul{}, quant_conv, l_factor);
} }
// convert quant_conv output to float type, multiply the factor and // convert quant_conv output to float type, multiply the factor and
// convert back to original type // convert back to original type
else else
{ {
auto float_conv = auto float_conv =
prog.insert_instruction(ins, op::convert{shape::float_type}, quant_conv); modl.insert_instruction(ins, op::convert{shape::float_type}, quant_conv);
auto l_factor = prog.add_literal(literal(float_conv->get_shape(), vec_factor)); auto l_factor = modl.add_literal(literal(float_conv->get_shape(), vec_factor));
if(orig_type == shape::float_type) if(orig_type == shape::float_type)
{ {
prog.replace_instruction(ins, op::mul{}, l_factor, float_conv); modl.replace_instruction(ins, op::mul{}, l_factor, float_conv);
} }
else else
{ {
auto adjusted_conv = prog.insert_instruction(ins, op::mul{}, l_factor, float_conv); auto adjusted_conv = modl.insert_instruction(ins, op::mul{}, l_factor, float_conv);
prog.replace_instruction(ins, op::convert{orig_type}, adjusted_conv); modl.replace_instruction(ins, op::convert{orig_type}, adjusted_conv);
} }
} }
} }
...@@ -338,10 +339,11 @@ void quantize_int8_impl(program& prog, ...@@ -338,10 +339,11 @@ void quantize_int8_impl(program& prog,
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation"); MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
} }
auto* mm = prog.get_main_module();
std::size_t quant_param_index = 0; std::size_t quant_param_index = 0;
std::unordered_map<instruction_ref, instruction_ref> map_quant_ins; std::unordered_map<instruction_ref, instruction_ref> map_quant_ins;
std::unordered_map<instruction_ref, std::size_t> map_ins_index; std::unordered_map<instruction_ref, std::size_t> map_ins_index;
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(*mm))
{ {
if(ins->name() == "@return") if(ins->name() == "@return")
break; break;
...@@ -398,7 +400,7 @@ void quantize_int8_impl(program& prog, ...@@ -398,7 +400,7 @@ void quantize_int8_impl(program& prog,
else else
{ {
quant_input = insert_quant_ins( quant_input = insert_quant_ins(
prog, input, quant_type, map_quant_ins, param.first, param.second); *mm, input, quant_type, map_quant_ins, param.first, param.second);
} }
converted_inputs.push_back(quant_input); converted_inputs.push_back(quant_input);
} }
...@@ -414,7 +416,7 @@ void quantize_int8_impl(program& prog, ...@@ -414,7 +416,7 @@ void quantize_int8_impl(program& prog,
continue; continue;
} }
ins_quantize_int8(prog, ins, converted_inputs, ins_quant_params); ins_quantize_int8(*mm, ins, converted_inputs, ins_quant_params);
} }
if(quant_param_index != quant_params.size()) if(quant_param_index != quant_params.size())
...@@ -425,7 +427,7 @@ void quantize_int8_impl(program& prog, ...@@ -425,7 +427,7 @@ void quantize_int8_impl(program& prog,
void quantize_int8(program& prog, void quantize_int8(program& prog,
const target& t, const target& t,
const std::vector<program::parameter_map>& calibration, const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names) const std::vector<std::string>& ins_names)
{ {
// insert capture operator // insert capture operator
...@@ -439,7 +441,7 @@ void quantize_int8(program& prog, ...@@ -439,7 +441,7 @@ void quantize_int8(program& prog,
// quantization scale and shift // quantization scale and shift
for(auto&& arg : calibration) for(auto&& arg : calibration)
{ {
program::parameter_map m; parameter_map m;
for(auto&& x : cap_prog.get_parameter_shapes()) for(auto&& x : cap_prog.get_parameter_shapes())
{ {
if(arg.count(x.first) > 0) if(arg.count(x.first) > 0)
...@@ -464,7 +466,7 @@ std::size_t capture_arguments(program& prog, ...@@ -464,7 +466,7 @@ std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names, const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func) const std::function<void(std::size_t, std::vector<argument>)>& func)
{ {
auto* mm = prog.get_main_module();
size_t num_quant_params = 0; size_t num_quant_params = 0;
// the int8 quantization only support dot and convolution // the int8 quantization only support dot and convolution
std::set<std::string> op_names = {"dot", "convolution"}; std::set<std::string> op_names = {"dot", "convolution"};
...@@ -476,7 +478,7 @@ std::size_t capture_arguments(program& prog, ...@@ -476,7 +478,7 @@ std::size_t capture_arguments(program& prog,
} }
std::unordered_map<instruction_ref, instruction_ref> ins_map; std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(*mm))
{ {
if(not contains(ins_names, ins->name())) if(not contains(ins_names, ins->name()))
{ {
...@@ -494,7 +496,7 @@ std::size_t capture_arguments(program& prog, ...@@ -494,7 +496,7 @@ std::size_t capture_arguments(program& prog,
} }
else else
{ {
new_ins = prog.insert_instruction( new_ins = mm->insert_instruction(
std::next(input), op::capture{num_quant_params++, func}, input); std::next(input), op::capture{num_quant_params++, func}, input);
ins_map[input] = new_ins; ins_map[input] = new_ins;
} }
......
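The func argument threaded through capture_arguments receives each capture's parameter index together with its runtime arguments; a hedged sketch of a compatible callback (the body is a placeholder, not the library's calibrator):

auto func = [](std::size_t param_index, std::vector<argument> args) {
    // a real calibrator would scan args here and record per-parameter
    // ranges from which the int8 scale/shift pairs are later derived
    (void)param_index;
    (void)args;
};
// assumed usage: capture_arguments(prog, {"dot", "convolution"}, func);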
...@@ -22,7 +22,7 @@ struct find_dot_add ...@@ -22,7 +22,7 @@ struct find_dot_add
match::name("dot")(match::nargs(2)).bind("dot")))); match::name("dot")(match::nargs(2)).bind("dot"))));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto dot_ins = r.instructions["dot"]; auto dot_ins = r.instructions["dot"];
...@@ -36,7 +36,7 @@ struct find_dot_add ...@@ -36,7 +36,7 @@ struct find_dot_add
}; };
} // namespace } // namespace
void remap::apply(program& p) const { match::find_matches(p, find_dot_add{}); } void remap::apply(module& p) const { match::find_matches(p, find_dot_add{}); }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void rewrite_batchnorm::apply(program& p) const void rewrite_batchnorm::apply(module& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void rewrite_pooling::apply(program& prog) const void rewrite_pooling::apply(module& prog) const
{ {
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(prog))
{ {
......
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void rewrite_rnn::apply(program& prog) const void rewrite_rnn::apply(module& prog) const
{ {
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(prog))
{ {
...@@ -47,13 +47,13 @@ void rewrite_rnn::apply(program& prog) const ...@@ -47,13 +47,13 @@ void rewrite_rnn::apply(program& prog) const
} }
} }
void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
{ {
assert(ins->name() == "rnn"); assert(ins->name() == "rnn");
// could be 3 to 6 inputs, but the parse_rnn function will // could be 3 to 6 inputs, but the parse_rnn function will
// append undefined operators to make 6 arguments when parsing // append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is that the user can pass any number of arguments // an onnx file. Another case is that the user can pass any number of arguments
// when writing their program. // when writing their module.
auto args = ins->inputs(); auto args = ins->inputs();
shape seq_shape = args[0]->get_shape(); shape seq_shape = args[0]->get_shape();
...@@ -210,7 +210,7 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const ...@@ -210,7 +210,7 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
} }
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward, std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
program& prog, module& prog,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
operation& actv_func) const operation& actv_func) const
...@@ -336,7 +336,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) ...@@ -336,7 +336,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins)
} }
} }
void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
{ {
assert(ins->name() == "gru"); assert(ins->name() == "gru");
const auto actv_funcs = gru_actv_funcs(ins); const auto actv_funcs = gru_actv_funcs(ins);
...@@ -502,7 +502,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const ...@@ -502,7 +502,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
} }
std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
program& prog, module& prog,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
int linear_before_reset, int linear_before_reset,
...@@ -685,7 +685,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const ...@@ -685,7 +685,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
} }
// for lstm operators // for lstm operators
void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
{ {
assert(ins->name() == "lstm"); assert(ins->name() == "lstm");
auto args = ins->inputs(); auto args = ins->inputs();
...@@ -927,7 +927,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -927,7 +927,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
} }
std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
program& prog, module& prog,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
const operation& actv_func1, const operation& actv_func1,
...@@ -1158,7 +1158,7 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const ...@@ -1158,7 +1158,7 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
} }
} }
bool rewrite_rnn::is_variable_seq_lens(const program& prog, instruction_ref seq_lens) const bool rewrite_rnn::is_variable_seq_lens(const module& prog, instruction_ref seq_lens) const
{ {
bool is_var_lens = false; bool is_var_lens = false;
if(seq_lens != prog.end()) if(seq_lens != prog.end())
...@@ -1188,7 +1188,7 @@ bool rewrite_rnn::is_variable_seq_lens(const program& prog, instruction_ref seq_ ...@@ -1188,7 +1188,7 @@ bool rewrite_rnn::is_variable_seq_lens(const program& prog, instruction_ref seq_
} }
std::size_t std::size_t
rewrite_rnn::get_seq_len(const program& prog, instruction_ref input, instruction_ref seq_lens) const rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const
{ {
bool is_var_lens = is_variable_seq_lens(prog, seq_lens); bool is_var_lens = is_variable_seq_lens(prog, seq_lens);
auto input_shape = input->get_shape(); auto input_shape = input->get_shape();
...@@ -1204,7 +1204,7 @@ rewrite_rnn::get_seq_len(const program& prog, instruction_ref input, instruction ...@@ -1204,7 +1204,7 @@ rewrite_rnn::get_seq_len(const program& prog, instruction_ref input, instruction
return length; return length;
} }
instruction_ref rewrite_rnn::replace_last_hs_output(program& prog, instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
instruction_ref ins, instruction_ref ins,
instruction_ref seq_lens, instruction_ref seq_lens,
instruction_ref last_hs_output, instruction_ref last_hs_output,
...@@ -1243,7 +1243,7 @@ instruction_ref rewrite_rnn::replace_last_hs_output(program& prog, ...@@ -1243,7 +1243,7 @@ instruction_ref rewrite_rnn::replace_last_hs_output(program& prog,
return result_ins; return result_ins;
} }
void rewrite_rnn::replace_last_cell_output(program& prog, void rewrite_rnn::replace_last_cell_output(module& prog,
instruction_ref ins, instruction_ref ins,
instruction_ref seq_lens, instruction_ref seq_lens,
instruction_ref cell_outputs, instruction_ref cell_outputs,
...@@ -1281,7 +1281,7 @@ void rewrite_rnn::replace_last_cell_output(program& prog, ...@@ -1281,7 +1281,7 @@ void rewrite_rnn::replace_last_cell_output(program& prog,
} }
} }
instruction_ref rewrite_rnn::pad_hidden_states(program& prog, instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
instruction_ref seq, instruction_ref seq,
instruction_ref seq_lens, instruction_ref seq_lens,
instruction_ref hs) const instruction_ref hs) const
......
...@@ -103,7 +103,7 @@ struct stream_info ...@@ -103,7 +103,7 @@ struct stream_info
} }
}; };
std::size_t assign_streams(program& p, std::size_t n) std::size_t assign_streams(module& p, std::size_t n)
{ {
assert(n > 0); assert(n > 0);
partition critical; partition critical;
...@@ -182,7 +182,7 @@ struct stream_info ...@@ -182,7 +182,7 @@ struct stream_info
} }
}; };
void sort(program& p, std::size_t) const void sort(module& p, std::size_t)
{ {
std::set<weight_ins, compare_weight_ins> children; std::set<weight_ins, compare_weight_ins> children;
std::unordered_map<instruction_ref, std::size_t> visited; std::unordered_map<instruction_ref, std::size_t> visited;
...@@ -335,7 +335,7 @@ struct stream_info ...@@ -335,7 +335,7 @@ struct stream_info
} }
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>>
find_concurrent_instructions(program& p) const find_concurrent_instructions(module& p) const
{ {
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result; std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from; std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
...@@ -378,7 +378,7 @@ struct stream_info ...@@ -378,7 +378,7 @@ struct stream_info
} }
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>
get_conflicts(program& p) get_conflicts(module& p)
{ {
using conflict_table_type = using conflict_table_type =
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>; std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>;
...@@ -464,7 +464,7 @@ struct stream_info ...@@ -464,7 +464,7 @@ struct stream_info
} }
}; };
void schedule::apply(program& p) const void schedule::apply(module& p) const
{ {
if(not enable) if(not enable)
return; return;
......
...@@ -50,7 +50,7 @@ struct find_mul_conv ...@@ -50,7 +50,7 @@ struct find_mul_conv
match::name("broadcast").bind("a"))); match::name("broadcast").bind("a")));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto conv_ins = r.instructions["conv"]; auto conv_ins = r.instructions["conv"];
...@@ -86,7 +86,7 @@ struct find_mul_slice_conv ...@@ -86,7 +86,7 @@ struct find_mul_slice_conv
match::name("broadcast")(match::is_constant()).bind("a"))); match::name("broadcast")(match::is_constant()).bind("a")));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto slice_ins = r.instructions["slice"]; auto slice_ins = r.instructions["slice"];
...@@ -169,7 +169,7 @@ struct find_mul_add ...@@ -169,7 +169,7 @@ struct find_mul_add
match::is_constant().bind("a"))); match::is_constant().bind("a")));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto a_ins = r.instructions["a"]; auto a_ins = r.instructions["a"];
...@@ -191,7 +191,7 @@ struct find_add_lit_broadcast ...@@ -191,7 +191,7 @@ struct find_add_lit_broadcast
match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b"))); match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b")));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
...@@ -211,7 +211,7 @@ struct find_double_add_lit_broadcast ...@@ -211,7 +211,7 @@ struct find_double_add_lit_broadcast
match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y"))); match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
...@@ -249,7 +249,7 @@ struct find_inner_broadcast ...@@ -249,7 +249,7 @@ struct find_inner_broadcast
match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y"))); match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
...@@ -294,7 +294,7 @@ struct find_concat_op ...@@ -294,7 +294,7 @@ struct find_concat_op
return op.name() == "broadcast" or op.attributes().contains("pointwise"); return op.name() == "broadcast" or op.attributes().contains("pointwise");
} }
void apply(program& p, const match::matcher_result& r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto axis = any_cast<op::concat>(ins->get_operator()).axis; auto axis = any_cast<op::concat>(ins->get_operator()).axis;
...@@ -425,7 +425,7 @@ struct find_splits ...@@ -425,7 +425,7 @@ struct find_splits
return groups; return groups;
} }
void apply(program& p, const match::matcher_result& r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
...@@ -520,7 +520,7 @@ struct find_split_concat ...@@ -520,7 +520,7 @@ struct find_split_concat
match::name("slice")(match::all_of[match::outputs()](match::name("concat"))))); match::name("slice")(match::all_of[match::outputs()](match::name("concat")))));
} }
void apply(program& p, const match::matcher_result& r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
...@@ -618,7 +618,7 @@ struct find_add_convs ...@@ -618,7 +618,7 @@ struct find_add_convs
input.strides()[3] * n}}; input.strides()[3] * n}};
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto a_conv = r.instructions["a"]; auto a_conv = r.instructions["a"];
...@@ -689,7 +689,7 @@ struct find_conv_dot_horiz_fusion ...@@ -689,7 +689,7 @@ struct find_conv_dot_horiz_fusion
{ {
auto matcher() const { return horiz_conv_dot(); } auto matcher() const { return horiz_conv_dot(); }
void apply(program& p, const match::matcher_result& r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
...@@ -762,7 +762,7 @@ struct find_div_const ...@@ -762,7 +762,7 @@ struct find_div_const
return match::name("div")(match::arg(1)(match::is_constant().bind("c"))); return match::name("div")(match::arg(1)(match::is_constant().bind("c")));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto c_ins = r.instructions["c"]; auto c_ins = r.instructions["c"];
...@@ -782,7 +782,7 @@ struct find_sub_const ...@@ -782,7 +782,7 @@ struct find_sub_const
return match::name("sub")(match::arg(1)(match::is_constant().bind("c"))); return match::name("sub")(match::arg(1)(match::is_constant().bind("c")));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto c_ins = r.instructions["c"]; auto c_ins = r.instructions["c"];
...@@ -803,7 +803,7 @@ struct find_rsqrt ...@@ -803,7 +803,7 @@ struct find_rsqrt
match::name("sqrt")(match::used_once(), match::args(match::any().bind("x"))))); match::name("sqrt")(match::used_once(), match::args(match::any().bind("x")))));
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
...@@ -828,7 +828,7 @@ struct find_split_reshape ...@@ -828,7 +828,7 @@ struct find_split_reshape
.bind("reshape"); .bind("reshape");
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto slc = r.instructions["slice"]; auto slc = r.instructions["slice"];
auto rsp = r.instructions["reshape"]; auto rsp = r.instructions["reshape"];
...@@ -904,7 +904,7 @@ struct find_split_transpose ...@@ -904,7 +904,7 @@ struct find_split_transpose
.bind("trans"); .bind("trans");
} }
void apply(program& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
{ {
auto slc = r.instructions["slice"]; auto slc = r.instructions["slice"];
auto trans = r.instructions["trans"]; auto trans = r.instructions["trans"];
...@@ -949,7 +949,7 @@ struct find_split_transpose ...@@ -949,7 +949,7 @@ struct find_split_transpose
} }
}; };
void simplify_algebra::apply(program& p) const void simplify_algebra::apply(module& p) const
{ {
// Run simplifications multiple times // Run simplifications multiple times
for(int i = 0; i < 8; i++) for(int i = 0; i < 8; i++)
......
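Every finder in this file pairs a matcher() with an apply(module&, ...) rewrite and is dispatched through match::find_matches; a hedged sketch in the same shape that rewrites x / c into x * recip(c) (the struct name is invented, and the use of the recip operator is an assumption, since find_div_const's body lies outside these hunks):

struct example_div_to_mul // illustrative only
{
    auto matcher() const
    {
        return match::name("div")(match::arg(1)(match::is_constant().bind("c")));
    }

    void apply(module& p, match::matcher_result r) const
    {
        auto ins   = r.result;
        auto c_ins = r.instructions["c"];
        auto recip = p.insert_instruction(ins, make_op("recip"), c_ins);
        p.replace_instruction(ins, make_op("mul"), ins->inputs().front(), recip);
    }
};
// registered like the finders above: match::find_matches(p, example_div_to_mul{});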
...@@ -66,7 +66,7 @@ struct find_reshaper ...@@ -66,7 +66,7 @@ struct find_reshaper
match::any_of[match::outputs()](match::name(reshaper_names()))); match::any_of[match::outputs()](match::name(reshaper_names())));
} }
void apply(program& p, const match::matcher_result& mr) const void apply(module& p, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
std::vector<instruction_ref> reshapes{ins}; std::vector<instruction_ref> reshapes{ins};
...@@ -113,7 +113,7 @@ struct find_nop_reshapes ...@@ -113,7 +113,7 @@ struct find_nop_reshapes
return match::name(reshapes)(match::same_shape(match::arg(0))); return match::name(reshapes)(match::same_shape(match::arg(0)));
} }
void apply(program& p, const match::matcher_result& mr) const void apply(module& p, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
p.replace_instruction(ins, ins->inputs().front()); p.replace_instruction(ins, ins->inputs().front());
...@@ -128,7 +128,7 @@ struct find_transpose ...@@ -128,7 +128,7 @@ struct find_transpose
match::skip_output(match::name("contiguous"))(match::name("transpose")))); match::skip_output(match::name("contiguous"))(match::name("transpose"))));
} }
void apply(program& p, const match::matcher_result& mr) const void apply(module& p, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto x = ins; auto x = ins;
...@@ -201,7 +201,7 @@ struct find_nested_slice ...@@ -201,7 +201,7 @@ struct find_nested_slice
return result; return result;
} }
void apply(program& p, const match::matcher_result& mr) const void apply(module& p, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto slice = ins->inputs().front(); auto slice = ins->inputs().front();
...@@ -230,7 +230,7 @@ struct find_concat_transpose ...@@ -230,7 +230,7 @@ struct find_concat_transpose
return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape())); return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape()));
} }
void apply(program& p, const match::matcher_result& mr) const void apply(module& p, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto trans_inputs = ins->inputs(); auto trans_inputs = ins->inputs();
...@@ -279,7 +279,7 @@ struct find_nested_concat ...@@ -279,7 +279,7 @@ struct find_nested_concat
return op.axis; return op.axis;
} }
void apply(program& p, const match::matcher_result& mr) const void apply(module& p, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto axis = get_axis(ins); auto axis = get_axis(ins);
...@@ -298,7 +298,7 @@ struct find_nested_concat ...@@ -298,7 +298,7 @@ struct find_nested_concat
} }
}; };
void simplify_reshapes::apply(program& p) const void simplify_reshapes::apply(module& p) const
{ {
for(int i = 0; i < 2; i++) for(int i = 0; i < 2; i++)
{ {
......