Unverified Commit 2466dd6f authored by Shucai Xiao, committed by GitHub

Refactor program to module (#684)



* code backup

* clang format

* change corresponding tool files

* clang format

Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent de10423f
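
The whole change pivots on a transitional type alias: each pass header now declares module as a second name for program, so pass interfaces can migrate to module& while both names still denote the same type, and the real program/module split can land later without touching the passes again. In sketch form, condensed from the first hunk below:

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct program;
using module = program; // transitional alias: module == program for now

struct schedule
{
    schedule_model model{};
    bool enable = true;
    std::string name() const { return "schedule"; }
    void apply(module& p) const; // was: void apply(program& p) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx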
@@ -10,6 +10,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Schedule instructions for concurrent execution
@@ -19,7 +20,7 @@ struct schedule
schedule_model model{};
bool enable = true;
std::string name() const { return "schedule"; }
void apply(program& p) const;
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
@@ -16,6 +16,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct operation;
#ifdef DOXYGEN
@@ -26,11 +27,11 @@ struct schedule_model
/// Get the number of concurrent instruction allowed
std::size_t concurrency() const;
/// Schedule a concurrent instruction
void sched(program& p, instruction_ref ins, std::size_t n) const;
void sched(module& p, instruction_ref ins, std::size_t n) const;
// Insert necessary waits before an instruction
void wait(program& p, instruction_ref ins, std::size_t wait_id) const;
void wait(module& p, instruction_ref ins, std::size_t wait_id) const;
// Insert necessary records after an instruction
void record(program& p, instruction_ref ins, std::size_t wait_id) const;
void record(module& p, instruction_ref ins, std::size_t wait_id) const;
/// Compute weights for an operation
std::size_t weight(const operation& op) const;
};
@@ -43,9 +44,9 @@ struct schedule_model
* struct schedule_model
* {
* std::size_t concurrency() const;
* void sched(program& p,instruction_ref ins,std::size_t n) const;
* void wait(program& p,instruction_ref ins,std::size_t wait_id) const;
* void record(program& p,instruction_ref ins,std::size_t wait_id) const;
* void sched(module& p,instruction_ref ins,std::size_t n) const;
* void wait(module& p,instruction_ref ins,std::size_t wait_id) const;
* void record(module& p,instruction_ref ins,std::size_t wait_id) const;
* std::size_t weight(const operation& op) const;
* };
*
@@ -120,19 +121,19 @@ struct schedule_model
return (*this).private_detail_te_get_handle().concurrency();
}
void sched(program& p, instruction_ref ins, std::size_t n) const
void sched(module& p, instruction_ref ins, std::size_t n) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().sched(p, ins, n);
}
void wait(program& p, instruction_ref ins, std::size_t wait_id) const
void wait(module& p, instruction_ref ins, std::size_t wait_id) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().wait(p, ins, wait_id);
}
void record(program& p, instruction_ref ins, std::size_t wait_id) const
void record(module& p, instruction_ref ins, std::size_t wait_id) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().record(p, ins, wait_id);
@@ -158,11 +159,11 @@ struct schedule_model
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::size_t concurrency() const = 0;
virtual void sched(program& p, instruction_ref ins, std::size_t n) const = 0;
virtual void wait(program& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual void record(program& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual std::size_t weight(const operation& op) const = 0;
virtual std::size_t concurrency() const = 0;
virtual void sched(module& p, instruction_ref ins, std::size_t n) const = 0;
virtual void wait(module& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual void record(module& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual std::size_t weight(const operation& op) const = 0;
};
template <typename PrivateDetailTypeErasedT>
@@ -195,19 +196,19 @@ struct schedule_model
std::size_t concurrency() const override { return private_detail_te_value.concurrency(); }
void sched(program& p, instruction_ref ins, std::size_t n) const override
void sched(module& p, instruction_ref ins, std::size_t n) const override
{
private_detail_te_value.sched(p, ins, n);
}
void wait(program& p, instruction_ref ins, std::size_t wait_id) const override
void wait(module& p, instruction_ref ins, std::size_t wait_id) const override
{
private_detail_te_value.wait(p, ins, wait_id);
}
void record(program& p, instruction_ref ins, std::size_t wait_id) const override
void record(module& p, instruction_ref ins, std::size_t wait_id) const override
{
private_detail_te_value.record(p, ins, wait_id);
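Under the type erasure above, any user-supplied model must now take module& in its scheduling hooks. A minimal conforming model, as a sketch (the trivial bodies are placeholders for illustration, not part of this commit):

struct null_schedule_model
{
    // allow only one stream, schedule nothing, and weight every op equally
    std::size_t concurrency() const { return 1; }
    void sched(module&, instruction_ref, std::size_t) const {}
    void wait(module&, instruction_ref, std::size_t) const {}
    void record(module&, instruction_ref, std::size_t) const {}
    std::size_t weight(const operation&) const { return 1; }
};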
@@ -8,6 +8,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Simplify many algebraic instructions to more efficient versions.
@@ -15,7 +16,7 @@ struct program;
struct simplify_algebra
{
std::string name() const { return "simplify_algebra"; }
void apply(program& p) const;
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Eliminate redundant reshapes.
@@ -16,7 +17,7 @@ struct program;
struct simplify_reshapes
{
std::string name() const { return "simplify_reshapes"; }
void apply(program& p) const;
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
@@ -61,7 +61,7 @@ int main(int argc, char const* argv[])
{
// GPU target
prog.compile(migraphx::gpu::target{});
migraphx::program::parameter_map m;
migraphx::parameter_map m;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}};
for(auto&& x : prog.get_parameter_shapes())
{
@@ -124,7 +124,7 @@ int main(int argc, char const* argv[])
auto s = migraphx::shape{migraphx::shape::float_type, {1, 1, 28, 28}};
std::cout << s << std::endl;
auto* ptr = input.data();
migraphx::program::parameter_map m;
migraphx::parameter_map m;
m["output"] =
migraphx::gpu::to_gpu(migraphx::generate_argument(prog.get_parameter_shape("output")));
for(int i = 0; i < 20; i++)
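In both example drivers the only change is the spelling of the parameter map: migraphx::program::parameter_map becomes the namespace-level migraphx::parameter_map. A condensed usage sketch, assuming prog is a compiled program as in the examples above:

migraphx::parameter_map m; // was: migraphx::program::parameter_map
for(auto&& x : prog.get_parameter_shapes())
{
    // fill every parameter with generated data of the declared shape
    m[x.first] = migraphx::generate_argument(x.second);
}
auto results = prog.eval(m);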
@@ -68,6 +68,7 @@ struct onnx_parser
node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
module* mm = prog.get_main_module();
bool is_pytorch = false;
std::size_t default_dim_value = 1;
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
@@ -266,11 +267,11 @@ struct onnx_parser
if(broadcasted != 0)
{
uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>();
auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
args[1]);
return prog.add_instruction(make_op(op_name), args[0], l);
auto l = mm->add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
args[1]);
return mm->add_instruction(make_op(op_name), args[0], l);
}
return prog.add_instruction(make_op(op_name), args);
return mm->add_instruction(make_op(op_name), args);
}
else
{
@@ -318,14 +319,14 @@ struct onnx_parser
return out_lens;
}
instruction_ref make_contiguous(instruction_ref ins)
instruction_ref make_contiguous(instruction_ref ins) const
{
if(ins->get_shape().standard())
{
return ins;
}
return prog.add_instruction(make_op("contiguous"), ins);
return mm->add_instruction(make_op("contiguous"), ins);
}
instruction_ref
@@ -340,17 +341,17 @@ struct onnx_parser
auto l0 = arg0;
if(arg0->get_shape().lens() != out_lens)
l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);
l0 = mm->add_instruction(op::multibroadcast{out_lens}, arg0);
auto l1 = arg1;
if(arg1->get_shape().lens() != out_lens)
l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);
l1 = mm->add_instruction(op::multibroadcast{out_lens}, arg1);
return prog.add_instruction(make_op(name), l0, l1);
return mm->add_instruction(make_op(name), l0, l1);
}
else
{
return prog.add_instruction(make_op(name), {arg0, arg1});
return mm->add_instruction(make_op(name), {arg0, arg1});
}
}
@@ -368,7 +369,7 @@ struct onnx_parser
return this->make_contiguous(arg);
});
}
return prog.add_instruction(op, args);
return mm->add_instruction(op, args);
});
}
@@ -391,14 +392,15 @@ struct onnx_parser
return output_vector;
}
instruction_ref
add_bias(const std::vector<instruction_ref>& args, instruction_ref curr_ins, uint64_t axis)
instruction_ref add_bias(const std::vector<instruction_ref>& args,
instruction_ref curr_ins,
uint64_t axis) const
{
if(args.size() == 3)
{
auto bias_bcast =
prog.add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]);
return prog.add_instruction(make_op("add"), curr_ins, bias_bcast);
mm->add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]);
return mm->add_instruction(make_op("add"), curr_ins, bias_bcast);
}
return curr_ins;
}
@@ -422,7 +424,7 @@ struct onnx_parser
const std::vector<int64_t>& padding,
value& v,
int count_include_pad = 0,
float pad_val = 0)
float pad_val = 0) const
{
size_t pad_ndims = padding.size() / 2;
auto left_pad_it = padding.begin();
@@ -435,7 +437,7 @@ struct onnx_parser
asym_pads.insert(asym_pads.begin() + 2, left_pad_it, right_pad_it);
// add right pads
asym_pads.insert(asym_pads.begin() + pad_ndims + 4, right_pad_it, padding.end());
ins = prog.add_instruction(op::pad{asym_pads, pad_val}, ins);
ins = mm->add_instruction(op::pad{asym_pads, pad_val}, ins);
}
else
{
@@ -444,7 +446,7 @@ struct onnx_parser
}
instruction_ref
parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args)
parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args) const
{
auto input_lens = args[0]->get_shape().lens();
instruction_ref min_arg;
@@ -469,44 +471,44 @@ struct onnx_parser
float min_val = parse_value(info.attributes.at("min")).at<float>();
float max_val = parse_value(info.attributes.at("max")).at<float>();
min_arg = prog.add_literal(min_val);
max_arg = prog.add_literal(max_val);
min_arg = mm->add_literal(min_val);
max_arg = mm->add_literal(max_val);
min_used = true;
max_used = true;
}
if(min_used)
{
min_arg = prog.add_instruction(op::multibroadcast{input_lens}, min_arg);
min_arg = mm->add_instruction(op::multibroadcast{input_lens}, min_arg);
}
if(max_used)
{
max_arg = prog.add_instruction(op::multibroadcast{input_lens}, max_arg);
max_arg = mm->add_instruction(op::multibroadcast{input_lens}, max_arg);
}
if(min_used and max_used)
{
return prog.add_instruction(make_op("clip"), args[0], min_arg, max_arg);
return mm->add_instruction(make_op("clip"), args[0], min_arg, max_arg);
}
else if(max_used)
{
return prog.add_instruction(make_op("min"), args[0], max_arg);
return mm->add_instruction(make_op("min"), args[0], max_arg);
}
else if(min_used)
{
return prog.add_instruction(make_op("max"), args[0], min_arg);
return mm->add_instruction(make_op("max"), args[0], min_arg);
}
else
{
return prog.add_instruction(make_op("identity"), args[0]);
return mm->add_instruction(make_op("identity"), args[0]);
}
}
instruction_ref parse_arg_op(const std::string&,
const std::string& op_name,
node_info info,
std::vector<instruction_ref> args)
std::vector<instruction_ref> args) const
{
int64_t axis = 0;
if(contains(info.attributes, "axis"))
@@ -522,12 +524,12 @@ struct onnx_parser
if(keep_dims == 0)
{
auto ins = prog.add_instruction(make_op(op_name, {{"axis", axis}}), std::move(args));
return prog.add_instruction(op::squeeze{{axis}}, ins);
auto ins = mm->add_instruction(make_op(op_name, {{"axis", axis}}), std::move(args));
return mm->add_instruction(op::squeeze{{axis}}, ins);
}
else
{
return prog.add_instruction(make_op(op_name, {{"axis", axis}}), std::move(args));
return mm->add_instruction(make_op(op_name, {{"axis", axis}}), std::move(args));
}
}
@@ -591,7 +593,7 @@ struct onnx_parser
{
*starts_it = idx;
*ends_it = *starts_it + 1;
slices.push_back(prog.add_instruction(op::slice{axes, starts, ends}, input));
slices.push_back(mm->add_instruction(op::slice{axes, starts, ends}, input));
}
// when padding on the left side, the outermost pad should be at the beginning
std::reverse(slices.begin(), slices.end());
@@ -600,9 +602,9 @@ struct onnx_parser
{
*starts_it = *dims_it - idx - 1;
*ends_it = *starts_it + 1;
slices.push_back(prog.add_instruction(op::slice{axes, starts, ends}, input));
slices.push_back(mm->add_instruction(op::slice{axes, starts, ends}, input));
}
input = prog.add_instruction(op::concat{axis}, slices);
input = mm->add_instruction(op::concat{axis}, slices);
}
return input;
}
@@ -747,7 +749,7 @@ struct onnx_parser
recalc_conv_attributes(values, kdims);
op.from_value(values);
auto l1 = prog.add_instruction(op, l0, args[1]);
auto l1 = mm->add_instruction(op, l0, args[1]);
return add_bias(args, l1, 1);
}
@@ -821,7 +823,7 @@ struct onnx_parser
recalc_conv_attributes(values, kdims);
op.from_value(values);
auto l1 = prog.add_instruction(op, l0, args[1]);
auto l1 = mm->add_instruction(op, l0, args[1]);
std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
std::vector<int64_t> curr_shape(dims.begin() + 2, dims.end());
if(asym_padding)
@@ -839,7 +841,7 @@ struct onnx_parser
std::back_inserter(ends),
[](auto curr_dim, auto pad_dim) { return curr_dim - pad_dim; });
l1 = prog.add_instruction(op::slice{axes, starts, ends}, l1);
l1 = mm->add_instruction(op::slice{axes, starts, ends}, l1);
}
if(contains(info.attributes, "output_padding"))
@@ -850,7 +852,7 @@ struct onnx_parser
check_attr_sizes(kdims,
output_padding.size() - non_kdims,
"PARSE_CONV_TRANSPOSE: inconsistent output padding");
l1 = prog.add_instruction(op::pad{output_padding}, l1);
l1 = mm->add_instruction(op::pad{output_padding}, l1);
}
if(contains(info.attributes, "output_shape"))
......@@ -869,7 +871,7 @@ struct onnx_parser
curr_shape.begin(),
std::back_inserter(target_padding),
[](auto out_dim, auto curr_dim) { return out_dim - curr_dim; });
l1 = prog.add_instruction(op::pad{target_padding}, l1);
l1 = mm->add_instruction(op::pad{target_padding}, l1);
}
}
@@ -1042,12 +1044,12 @@ struct onnx_parser
}
}
op.from_value(values);
auto l1 = prog.add_instruction(op, l0);
auto l1 = mm->add_instruction(op, l0);
if(!slice_start.empty())
{
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2);
l1 = prog.add_instruction(op::slice{axes, slice_start, slice_end}, l1);
l1 = mm->add_instruction(op::slice{axes, slice_start, slice_end}, l1);
}
return l1;
@@ -1069,7 +1071,7 @@ struct onnx_parser
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
return prog.add_instruction(op, make_contiguous(args[0]));
return mm->add_instruction(op, make_contiguous(args[0]));
}
static const auto& get_nearest_op(const std::string& mode)
@@ -1248,9 +1250,9 @@ struct onnx_parser
// reshape input to one-dimension
std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
shape ind_s{shape::int32_type, out_lens};
auto rsp = prog.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
auto ins_ind = prog.add_literal(literal(ind_s, ind));
return prog.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
auto rsp = mm->add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
auto ins_ind = mm->add_literal(literal(ind_s, ind));
return mm->add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
}
instruction_ref
@@ -1281,7 +1283,7 @@ struct onnx_parser
int64_t data_elem_num = static_cast<int64_t>(data_s.elements());
// reshape the input data as one dimension and used as input data
// to the gather operator
arg_data = prog.add_instruction(op::reshape{{data_elem_num}}, arg_data);
arg_data = mm->add_instruction(op::reshape{{data_elem_num}}, arg_data);
std::size_t elem_num = ind_s.elements();
std::vector<int> ind_index(elem_num);
@@ -1299,16 +1301,16 @@ struct onnx_parser
});
auto l_shape_idx =
prog.add_literal(literal(ind_s, data_indices.begin(), data_indices.end()));
auto l_dim_idx = prog.add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
auto l_stride = prog.add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
l_stride = prog.add_instruction(op::multibroadcast{ind_s.lens()}, l_stride);
auto dim_diff = prog.add_instruction(make_op("sub"), arg_ind, l_dim_idx);
auto delta = prog.add_instruction(make_op("mul"), dim_diff, l_stride);
auto ind = prog.add_instruction(make_op("add"), l_shape_idx, delta);
mm->add_literal(literal(ind_s, data_indices.begin(), data_indices.end()));
auto l_dim_idx = mm->add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
auto l_stride = mm->add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
l_stride = mm->add_instruction(op::multibroadcast{ind_s.lens()}, l_stride);
auto dim_diff = mm->add_instruction(make_op("sub"), arg_ind, l_dim_idx);
auto delta = mm->add_instruction(make_op("mul"), dim_diff, l_stride);
auto ind = mm->add_instruction(make_op("add"), l_shape_idx, delta);
op::gather op{0};
return prog.add_instruction(op, arg_data, ind);
return mm->add_instruction(op, arg_data, ind);
}
instruction_ref
@@ -1373,17 +1375,17 @@ struct onnx_parser
op.axes = axes;
}
return prog.add_instruction(op, args[0]);
return mm->add_instruction(op, args[0]);
}
instruction_ref
parse_constant(const std::string&, node_info info, const std::vector<instruction_ref>&)
parse_constant(const std::string&, node_info info, const std::vector<instruction_ref>&) const
{
literal v = parse_value(info.attributes.at("value"));
// return empty literal
if(v.get_shape().elements() == 0)
{
return prog.add_literal(literal{});
return mm->add_literal(literal{});
}
auto dim_size = info.attributes.at("value").t().dims_size();
@@ -1391,14 +1393,14 @@ struct onnx_parser
if(dim_size == 0)
{
migraphx::shape scalar_shape{v.get_shape().type()};
return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
return mm->add_literal(migraphx::literal{scalar_shape, v.data()});
}
return prog.add_literal(v);
return mm->add_literal(v);
}
instruction_ref
parse_gemm(const std::string&, node_info info, std::vector<instruction_ref> args)
parse_gemm(const std::string&, node_info info, std::vector<instruction_ref> args) const
{
float alpha = 1.0f;
float beta = 1.0f;
@@ -1426,8 +1428,8 @@ struct onnx_parser
// swap the last two elements
std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
auto l1 = (transa) ? mm->add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? mm->add_instruction(op::transpose{perm}, args[1]) : args[1];
if(args.size() == 3)
{
if(beta != 0.f && args[2]->get_shape().elements() > 0)
@@ -1438,14 +1440,14 @@ struct onnx_parser
auto l3_lens = l3->get_shape().lens();
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{
l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
l3 = mm->add_instruction(op::multibroadcast{out_lens}, args[2]);
}
return prog.add_instruction(
return mm->add_instruction(
make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
}
}
return prog.add_instruction(make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2);
return mm->add_instruction(make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2);
}
instruction_ref parse_matmul(const std::string&,
@@ -1464,7 +1466,7 @@ struct onnx_parser
{
is_a_prepended = true;
l0_lens.insert(l0_lens.begin(), 1);
l0 = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
l0 = mm->add_instruction(op::unsqueeze{{0}}, args[0]);
}
bool is_b_appended = false;
@@ -1472,7 +1474,7 @@ struct onnx_parser
{
is_b_appended = true;
l1_lens.push_back(1);
l1 = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
l1 = mm->add_instruction(op::unsqueeze{{1}}, args[1]);
}
instruction_ref bl0 = l0;
@@ -1490,32 +1492,31 @@ struct onnx_parser
l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
if(l0_lens != l0_broadcasted_lens)
{
bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
bl0 = mm->add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
}
if(l1_lens != l1_broadcasted_lens)
{
bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
bl1 = mm->add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
}
}
auto dot_res =
prog.add_instruction(make_op(op_name, {{"alpha", 1}, {"beta", 0}}), bl0, bl1);
auto dot_res = mm->add_instruction(make_op(op_name, {{"alpha", 1}, {"beta", 0}}), bl0, bl1);
int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
if(is_a_prepended)
{
dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
dot_res = mm->add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
--num_axis;
}
if(is_b_appended)
{
dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
dot_res = mm->add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
}
return dot_res;
}
instruction_ref
parse_batchnorm(const std::string&, node_info info, std::vector<instruction_ref> args)
parse_batchnorm(const std::string&, node_info info, std::vector<instruction_ref> args) const
{
float epsilon = 1e-5f;
float momentum = 0.9f;
@@ -1535,11 +1536,11 @@ struct onnx_parser
: op::batch_norm_inference::per_activation;
}
op::batch_norm_inference op{epsilon, momentum, bn_mode};
return prog.add_instruction(op, std::move(args));
return mm->add_instruction(op, std::move(args));
}
instruction_ref
parse_instancenorm(const std::string&, node_info info, std::vector<instruction_ref> args)
parse_instancenorm(const std::string&, node_info info, std::vector<instruction_ref> args) const
{
// y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
// mean = reduce_mean({D1, D2, ... Dk}, x)
@@ -1561,26 +1562,26 @@ struct onnx_parser
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2);
auto mean = prog.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_bcast = prog.add_instruction(op::multibroadcast{dims}, mean);
auto l0 = prog.add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = prog.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
auto l1 = prog.add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = prog.add_literal(epsilon);
auto epsilon_bcast = prog.add_instruction(op::multibroadcast{dims}, epsilon_literal);
auto variance_bcast = prog.add_instruction(op::multibroadcast{dims}, variance);
auto l2 = prog.add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = prog.add_instruction(make_op("rsqrt"), l2);
auto l4 = prog.add_instruction(make_op("mul"), l1, l3);
auto scale_bcast = prog.add_instruction(op::broadcast{1, dims}, scale);
auto mean = mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_bcast = mm->add_instruction(op::multibroadcast{dims}, mean);
auto l0 = mm->add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
auto l1 = mm->add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = mm->add_literal(epsilon);
auto epsilon_bcast = mm->add_instruction(op::multibroadcast{dims}, epsilon_literal);
auto variance_bcast = mm->add_instruction(op::multibroadcast{dims}, variance);
auto l2 = mm->add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = mm->add_instruction(make_op("rsqrt"), l2);
auto l4 = mm->add_instruction(make_op("mul"), l1, l3);
auto scale_bcast = mm->add_instruction(op::broadcast{1, dims}, scale);
auto bias_bcast = prog.add_instruction(op::broadcast{1, dims}, bias);
auto l5 = prog.add_instruction(make_op("mul"), l4, scale_bcast);
return prog.add_instruction(make_op("add"), l5, bias_bcast);
auto bias_bcast = mm->add_instruction(op::broadcast{1, dims}, bias);
auto l5 = mm->add_instruction(make_op("mul"), l4, scale_bcast);
return mm->add_instruction(make_op("add"), l5, bias_bcast);
}
instruction_ref
parse_leaky_relu(const std::string&, node_info info, std::vector<instruction_ref> args)
parse_leaky_relu(const std::string&, node_info info, std::vector<instruction_ref> args) const
{
float alpha = 0.01; // default alpha val for leaky relu
if(contains(info.attributes, "alpha"))
@@ -1588,10 +1589,11 @@ struct onnx_parser
alpha = parse_value(info.attributes.at("alpha")).at<float>();
}
auto op = make_op("leaky_relu", {{"alpha", alpha}});
return prog.add_instruction(op, args.front());
return mm->add_instruction(op, args.front());
}
instruction_ref parse_elu(const std::string&, node_info info, std::vector<instruction_ref> args)
instruction_ref
parse_elu(const std::string&, node_info info, std::vector<instruction_ref> args) const
{
float alpha = 1.0; // default alpha val for elu
if(contains(info.attributes, "alpha"))
@@ -1599,10 +1601,11 @@ struct onnx_parser
alpha = parse_value(info.attributes.at("alpha")).at<float>();
}
auto op = make_op("elu", {{"alpha", alpha}});
return prog.add_instruction(op, args.front());
return mm->add_instruction(op, args.front());
}
instruction_ref parse_lrn(const std::string&, node_info info, std::vector<instruction_ref> args)
instruction_ref
parse_lrn(const std::string&, node_info info, std::vector<instruction_ref> args) const
{
float alpha = 0.0001;
float beta = 0.75;
@@ -1617,11 +1620,11 @@ struct onnx_parser
if(contains(info.attributes, "size"))
size = parse_value(info.attributes.at("size")).at<int>();
op::lrn op{alpha, beta, bias, size};
return prog.add_instruction(op, args.front());
return mm->add_instruction(op, args.front());
}
instruction_ref
parse_imagescaler(const std::string&, node_info info, std::vector<instruction_ref> args)
parse_imagescaler(const std::string&, node_info info, std::vector<instruction_ref> args) const
{
float scale = 1.0;
std::vector<float> bias{};
@@ -1639,18 +1642,17 @@ struct onnx_parser
auto const& input_lens = input_shape.lens();
auto input_type = input_shape.type();
auto scale_val = prog.add_literal(literal{shape{input_type}, {scale}});
auto bias_vals = prog.add_literal(literal{shape{input_type, {bias.size()}}, bias});
auto scale_val = mm->add_literal(literal{shape{input_type}, {scale}});
auto bias_vals = mm->add_literal(literal{shape{input_type, {bias.size()}}, bias});
auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
auto img_scaled =
prog.add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor);
auto bias_bcast = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
return prog.add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
auto scale_tensor = mm->add_instruction(migraphx::op::scalar{input_lens}, scale_val);
auto img_scaled = mm->add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor);
auto bias_bcast = mm->add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
return mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
}
instruction_ref
parse_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
parse_transpose(const std::string&, node_info info, std::vector<instruction_ref> args) const
{
std::vector<int64_t> perm{};
if(contains(info.attributes, "perm"))
@@ -1658,7 +1660,7 @@ struct onnx_parser
auto&& perm_vals = info.attributes["perm"].ints();
perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
}
return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
return mm->add_instruction(migraphx::op::transpose{perm}, args.front());
}
instruction_ref parse_pad(const std::string&, node_info info, std::vector<instruction_ref> args)
@@ -1683,7 +1685,7 @@ struct onnx_parser
// check if padding is actually being done (at least one value is nonzero)
if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
{
return prog.add_instruction(make_op("identity"), args.front());
return mm->add_instruction(make_op("identity"), args.front());
}
if(contains(info.attributes, "mode"))
@@ -1719,11 +1721,11 @@ struct onnx_parser
value = parse_value(info.attributes.at("value")).at<float>();
}
return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
return mm->add_instruction(migraphx::op::pad{pads, value}, args.front());
}
instruction_ref
parse_selu(const std::string&, const node_info& info, std::vector<instruction_ref> args)
parse_selu(const std::string&, const node_info& info, std::vector<instruction_ref> args) const
{
auto type = args[0]->get_shape().type();
auto lens = args[0]->get_shape().lens();
@@ -1739,35 +1741,35 @@ struct onnx_parser
gamma = info.attributes.at("gamma").f();
}
auto l_alpha = prog.add_literal({{type, {1}}, {alpha}});
auto l_gamma = prog.add_literal({{type, {1}}, {gamma / 2.0f}});
auto l_alpha = mm->add_literal({{type, {1}}, {alpha}});
auto l_gamma = mm->add_literal({{type, {1}}, {gamma / 2.0f}});
if(lens != std::vector<std::size_t>{1})
{
l_alpha =
prog.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_alpha);
mm->add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_alpha);
l_gamma =
prog.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_gamma);
mm->add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_gamma);
}
auto sign_x = prog.add_instruction(make_op("sign"), args[0]);
auto exp_x = prog.add_instruction(make_op("exp"), args[0]);
auto sign_x = mm->add_instruction(make_op("sign"), args[0]);
auto exp_x = mm->add_instruction(make_op("exp"), args[0]);
auto alpha_ex = prog.add_instruction(make_op("mul"), l_alpha, exp_x);
auto aex_alpha = prog.add_instruction(make_op("sub"), alpha_ex, l_alpha);
auto alpha_ex = mm->add_instruction(make_op("mul"), l_alpha, exp_x);
auto aex_alpha = mm->add_instruction(make_op("sub"), alpha_ex, l_alpha);
auto ins1 = prog.add_instruction(make_op("add"), aex_alpha, args[0]);
auto ins2 = prog.add_instruction(make_op("sub"), aex_alpha, args[0]);
auto ins1 = mm->add_instruction(make_op("add"), aex_alpha, args[0]);
auto ins2 = mm->add_instruction(make_op("sub"), aex_alpha, args[0]);
auto sign2 = prog.add_instruction(make_op("mul"), sign_x, ins2);
auto ins_sub = prog.add_instruction(make_op("sub"), ins1, sign2);
auto sign2 = mm->add_instruction(make_op("mul"), sign_x, ins2);
auto ins_sub = mm->add_instruction(make_op("sub"), ins1, sign2);
return prog.add_instruction(make_op("mul"), ins_sub, l_gamma);
return mm->add_instruction(make_op("mul"), ins_sub, l_gamma);
}
// Use a literal instruction to replace the shape since, output of
// shape operator are literals in migraphx
instruction_ref
parse_shape(const std::string&, const node_info&, std::vector<instruction_ref> args)
parse_shape(const std::string&, const node_info&, std::vector<instruction_ref> args) const
{
if(args.size() != 1)
MIGRAPHX_THROW("Shape: operator should have 1 operand");
@@ -1777,7 +1779,7 @@ struct onnx_parser
std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
return int64_t(i);
});
return prog.add_literal(migraphx::literal{s, vec_shape});
return mm->add_literal(migraphx::literal{s, vec_shape});
}
// Use a literal instruction to replace the constantFill operator. In RNN, input shape
@@ -1831,7 +1833,7 @@ struct onnx_parser
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
migraphx::shape s(type, dims);
std::vector<float> values(s.elements(), value);
return prog.add_literal(migraphx::literal(s, values));
return mm->add_literal(migraphx::literal(s, values));
}
else if(input_as_shape == 0)
{
@@ -1845,7 +1847,7 @@ struct onnx_parser
ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
migraphx::shape s{type, dims};
std::vector<float> values(s.elements(), value);
return prog.add_literal(migraphx::literal(s, values));
return mm->add_literal(migraphx::literal(s, values));
}
else
{
@@ -1903,7 +1905,7 @@ struct onnx_parser
l_out = literal(s, out_vec);
});
return prog.add_literal(l_out);
return mm->add_literal(l_out);
}
}
@@ -1916,7 +1918,7 @@ struct onnx_parser
std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
return mm->add_instruction(op::multibroadcast{out_lens}, args[0]);
}
std::vector<instruction_ref>
@@ -1999,16 +2001,16 @@ struct onnx_parser
// undefined operator to have 6 arguments
if(args.size() < 6)
{
auto ins = prog.add_instruction(op::undefined{});
auto ins = mm->add_instruction(op::undefined{});
args.insert(args.end(), (6 - args.size()), ins);
}
// first output for the concatenation of hidden states
auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
std::move(args));
auto hidden_states =
mm->add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip}, std::move(args));
// second output for the last hidden state
auto last_output = prog.add_instruction(op::rnn_last_hs_output{}, hidden_states);
auto last_output = mm->add_instruction(op::rnn_last_hs_output{}, hidden_states);
return {hidden_states, last_output};
}
@@ -2120,17 +2122,17 @@ struct onnx_parser
// append undefined operator to make 6 arguments
if(args.size() < 6)
{
auto ins = prog.add_instruction(op::undefined{});
auto ins = mm->add_instruction(op::undefined{});
args.insert(args.end(), 6 - args.size(), ins);
}
// first output for concatenation of hidden states
auto hidden_states = prog.add_instruction(
auto hidden_states = mm->add_instruction(
op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
std::move(args));
// second output for last gru output
auto last_output = prog.add_instruction(op::rnn_last_hs_output{}, hidden_states);
auto last_output = mm->add_instruction(op::rnn_last_hs_output{}, hidden_states);
return {hidden_states, last_output};
}
@@ -2302,18 +2304,18 @@ struct onnx_parser
// append undefined operator to make 8 arguments
if(args.size() < 8)
{
auto ins = prog.add_instruction(op::undefined{});
auto ins = mm->add_instruction(op::undefined{});
args.insert(args.end(), 8 - args.size(), ins);
}
// first output for concatenation of hidden states
auto hidden_states = prog.add_instruction(
auto hidden_states = mm->add_instruction(
op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
auto last_output = prog.add_instruction(op::rnn_last_hs_output{}, hidden_states);
auto last_output = mm->add_instruction(op::rnn_last_hs_output{}, hidden_states);
// third output for last cell output
auto last_cell_output = prog.add_instruction(op::rnn_last_cell_output{}, hidden_states);
auto last_cell_output = mm->add_instruction(op::rnn_last_cell_output{}, hidden_states);
return {hidden_states, last_output, last_cell_output};
}
@@ -2321,7 +2323,7 @@ struct onnx_parser
instruction_ref parse_reduce_oper(const std::string&,
const std::string& op_name,
node_info info,
std::vector<instruction_ref> args)
std::vector<instruction_ref> args) const
{
std::size_t n_dim = args.front()->get_shape().lens().size();
@@ -2343,54 +2345,57 @@ struct onnx_parser
if(keep_dims == 1)
{
return prog.add_instruction(make_op(op_name, {{"axes", axes}}), std::move(args));
return mm->add_instruction(make_op(op_name, {{"axes", axes}}), std::move(args));
}
else
{
auto ins = prog.add_instruction(make_op(op_name, {{"axes", axes}}), std::move(args));
return prog.add_instruction(op::squeeze{axes}, ins);
auto ins = mm->add_instruction(make_op(op_name, {{"axes", axes}}), std::move(args));
return mm->add_instruction(op::squeeze{axes}, ins);
}
}
instruction_ref
parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args)
parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args) const
{
auto abs_ins = prog.add_instruction(make_op("abs"), args[0]);
auto abs_ins = mm->add_instruction(make_op("abs"), args[0]);
return parse_reduce_oper({}, "reduce_sum", std::move(info), {abs_ins});
}
instruction_ref
parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args)
parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args) const
{
auto square_ins = prog.add_instruction(make_op("mul"), args[0], args[0]);
auto square_ins = mm->add_instruction(make_op("mul"), args[0], args[0]);
auto sum_ins = parse_reduce_oper({}, "reduce_sum", std::move(info), {square_ins});
return prog.add_instruction(make_op("sqrt"), sum_ins);
return mm->add_instruction(make_op("sqrt"), sum_ins);
}
instruction_ref
parse_reduce_log_sum(const std::string&, node_info info, std::vector<instruction_ref> args)
instruction_ref parse_reduce_log_sum(const std::string&,
node_info info,
std::vector<instruction_ref> args) const
{
auto sum_ins = parse_reduce_oper({}, "reduce_sum", std::move(info), std::move(args));
return prog.add_instruction(make_op("log"), sum_ins);
return mm->add_instruction(make_op("log"), sum_ins);
}
instruction_ref
parse_reduce_log_sum_exp(const std::string&, node_info info, std::vector<instruction_ref> args)
instruction_ref parse_reduce_log_sum_exp(const std::string&,
node_info info,
std::vector<instruction_ref> args) const
{
auto exp_ins = prog.add_instruction(make_op("exp"), args[0]);
auto exp_ins = mm->add_instruction(make_op("exp"), args[0]);
auto sum_ins = parse_reduce_oper({}, "reduce_sum", std::move(info), {exp_ins});
return prog.add_instruction(make_op("log"), sum_ins);
return mm->add_instruction(make_op("log"), sum_ins);
}
instruction_ref
parse_reduce_sum_square(const std::string&, node_info info, std::vector<instruction_ref> args)
instruction_ref parse_reduce_sum_square(const std::string&,
node_info info,
std::vector<instruction_ref> args) const
{
auto square_ins = prog.add_instruction(make_op("mul"), args[0], args[0]);
auto square_ins = mm->add_instruction(make_op("mul"), args[0], args[0]);
return parse_reduce_oper({}, "reduce_sum", std::move(info), {square_ins});
}
instruction_ref
parse_cast(const std::string&, node_info info, std::vector<instruction_ref> args)
parse_cast(const std::string&, node_info info, std::vector<instruction_ref> args) const
{
if(!contains(info.attributes, "to"))
{
@@ -2399,7 +2404,7 @@ struct onnx_parser
int to_type = parse_value(info.attributes.at("to")).at<int>();
shape::type_t type = get_type(to_type);
return prog.add_instruction(make_op("convert", {{"target_type", type}}), std::move(args));
return mm->add_instruction(make_op("convert", {{"target_type", type}}), std::move(args));
}
std::vector<instruction_ref>
@@ -2448,7 +2453,7 @@ struct onnx_parser
for(auto sl : vec_splits)
{
ret_ins.push_back(
prog.add_instruction(op::slice{{axis}, {start}, {start + sl}}, args[0]));
mm->add_instruction(op::slice{{axis}, {start}, {start + sl}}, args[0]));
start += sl;
}
@@ -2476,8 +2481,8 @@ struct onnx_parser
auto type = args[2]->get_shape().type();
shape s{type, {depth, depth}};
auto l_val = prog.add_literal({s, depth_input});
auto gather_out = prog.add_instruction(op::gather{0}, {l_val, args[0]});
auto l_val = mm->add_literal({s, depth_input});
auto gather_out = mm->add_instruction(op::gather{0}, {l_val, args[0]});
// Finally, we need a transpose to move the innermost dim to the axis dim
int n_rank = gather_out->get_shape().lens().size();
@@ -2489,16 +2494,16 @@ struct onnx_parser
std::vector<int64_t> perm(n_rank - 1);
std::iota(perm.begin(), perm.end(), 0);
perm.insert(perm.begin() + tuned_axis, n_rank - 1);
auto tr_out = prog.add_instruction(op::transpose{perm}, gather_out);
auto tr_out = mm->add_instruction(op::transpose{perm}, gather_out);
auto lens = tr_out->get_shape().lens();
auto off_val = prog.add_instruction(op::slice{{0}, {0}, {1}}, args[2]);
auto on_val = prog.add_instruction(op::slice{{0}, {1}, {2}}, args[2]);
auto diff = prog.add_instruction(make_op("sub"), on_val, off_val);
auto unsq_off_val = prog.add_instruction(op::multibroadcast{lens}, off_val);
auto unsq_diff_val = prog.add_instruction(op::multibroadcast{lens}, diff);
auto l_mul = prog.add_instruction(make_op("mul"), tr_out, unsq_diff_val);
return prog.add_instruction(make_op("add"), l_mul, unsq_off_val);
auto off_val = mm->add_instruction(op::slice{{0}, {0}, {1}}, args[2]);
auto on_val = mm->add_instruction(op::slice{{0}, {1}, {2}}, args[2]);
auto diff = mm->add_instruction(make_op("sub"), on_val, off_val);
auto unsq_off_val = mm->add_instruction(op::multibroadcast{lens}, off_val);
auto unsq_diff_val = mm->add_instruction(op::multibroadcast{lens}, diff);
auto l_mul = mm->add_instruction(make_op("mul"), tr_out, unsq_diff_val);
return mm->add_instruction(make_op("add"), l_mul, unsq_off_val);
}
instruction_ref
@@ -2515,7 +2520,7 @@ struct onnx_parser
auto l1 = l0;
for(int j = 1; j < repeats[i]; j++)
{
l0 = prog.add_instruction(op::concat{i}, l0, l1);
l0 = mm->add_instruction(op::concat{i}, l0, l1);
}
}
return l0;
@@ -2557,7 +2562,7 @@ struct onnx_parser
return result;
});
l0 = prog.add_literal({shape{args[0]->get_shape().type(), {num_elements}}, range_vals});
l0 = mm->add_literal({shape{args[0]->get_shape().type(), {num_elements}}, range_vals});
});
return l0;
}
@@ -2569,7 +2574,8 @@ struct onnx_parser
max = 2
};
instruction_ref parse_embedding_bag(const node_info& info, std::vector<instruction_ref> args)
instruction_ref parse_embedding_bag(const node_info& info,
std::vector<instruction_ref> args) const
{
if(args[2]->get_shape().elements() != 1)
MIGRAPHX_THROW("PARSE_EMBEDDING_BAG: MIGraphX only supports offsets of size 1");
@@ -2579,24 +2585,24 @@ struct onnx_parser
reduce_mode = static_cast<reduce_mode_t>(info.attributes.at("mode").i());
}
auto l0 = prog.add_instruction(op::gather{}, args[0], args[1]);
auto l0 = mm->add_instruction(op::gather{}, args[0], args[1]);
switch(reduce_mode)
{
case reduce_mode_t::sum:
l0 = prog.add_instruction(make_op("reduce_sum", {{"axes", {0}}}), l0);
l0 = mm->add_instruction(make_op("reduce_sum", {{"axes", {0}}}), l0);
break;
case reduce_mode_t::mean:
l0 = prog.add_instruction(make_op("reduce_mean", {{"axes", {0}}}), l0);
l0 = mm->add_instruction(make_op("reduce_mean", {{"axes", {0}}}), l0);
break;
case reduce_mode_t::max:
l0 = prog.add_instruction(make_op("reduce_max", {{"axes", {0}}}), l0);
l0 = mm->add_instruction(make_op("reduce_max", {{"axes", {0}}}), l0);
break;
}
return l0;
}
instruction_ref
parse_aten(const std::string&, const node_info& info, std::vector<instruction_ref> args)
parse_aten(const std::string&, const node_info& info, std::vector<instruction_ref> args) const
{
if(contains(info.attributes, "operator"))
{
@@ -2610,13 +2616,13 @@ struct onnx_parser
}
std::vector<instruction_ref>
parse_dropout(const std::string&, const node_info&, std::vector<instruction_ref> args)
parse_dropout(const std::string&, const node_info&, std::vector<instruction_ref> args) const
{
auto out = prog.add_instruction(make_op("identity"), args[0]);
auto out = mm->add_instruction(make_op("identity"), args[0]);
auto s = args[0]->get_shape();
std::vector<int8_t> vec(s.elements(), 1);
shape mask_s{shape::bool_type, s.lens()};
auto mask = prog.add_literal(literal(mask_s, vec));
auto mask = mm->add_literal(literal(mask_s, vec));
return {out, mask};
}
@@ -2661,7 +2667,7 @@ struct onnx_parser
}
}
return prog.add_literal(literal(out_s, out_data));
return mm->add_literal(literal(out_s, out_data));
}
instruction_ref parse_compare_op(const std::string&,
@@ -2672,7 +2678,7 @@ struct onnx_parser
auto l = add_broadcastable_binary_op(args[0], args[1], op_name);
if(l->get_shape().type() != shape::bool_type)
{
l = prog.add_instruction(make_op("convert", {{"target_type", shape::bool_type}}), l);
l = mm->add_instruction(make_op("convert", {{"target_type", shape::bool_type}}), l);
}
return l;
}
@@ -2737,9 +2743,9 @@ struct onnx_parser
// reshape input to one-dimension
std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
shape ind_s{shape::int32_type, out_lens};
auto rsp = prog.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
auto ins_ind = prog.add_literal(literal(ind_s, ind));
return prog.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
auto rsp = mm->add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
auto ins_ind = mm->add_literal(literal(ind_s, ind));
return mm->add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
}
instruction_ref
......@@ -2748,7 +2754,7 @@ struct onnx_parser
auto type = args[1]->get_shape().type();
// the operation of if cond == 1 select x; else select y,
// is equivalent to cond * (x - y) + y
auto cond = prog.add_instruction(make_op("convert", {{"target_type", type}}), args[0]);
auto cond = mm->add_instruction(make_op("convert", {{"target_type", type}}), args[0]);
auto diff = add_broadcastable_binary_op(args[1], args[2], "sub");
auto cd = add_broadcastable_binary_op(diff, cond, "mul");
return add_broadcastable_binary_op(cd, args[2], "add");
@@ -2795,7 +2801,7 @@ struct onnx_parser
{
for(auto&& f : graph.initializer())
{
instructions[f.name()] = prog.add_literal(parse_tensor(f));
instructions[f.name()] = mm->add_literal(parse_tensor(f));
}
for(auto&& input : graph.input())
@@ -2811,7 +2817,7 @@ struct onnx_parser
}
shape s = parse_type(input.type(), dims);
instructions[name] = prog.add_parameter(name, s);
instructions[name] = mm->add_parameter(name, s);
}
}
@@ -2837,7 +2843,7 @@ struct onnx_parser
if(ops.count(node.op_type()) == 0)
{
if(skip_unknown_operators)
result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
result.push_back(mm->add_instruction(op::unknown{node.op_type()}, args));
else
MIGRAPHX_THROW("Unknown operator: " + node.op_type());
}
@@ -2875,14 +2881,14 @@ struct onnx_parser
[&](const auto& name) { return instructions[name]; });
// add the return instruction
prog.add_return(output_ins);
mm->add_return(output_ins);
}
void parse_undefined(const std::string& name)
{
if(!contains(instructions, name))
{
auto ins = prog.add_instruction(op::undefined{});
auto ins = mm->add_instruction(op::undefined{});
instructions[name] = ins;
}
}
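Every edit in the parser above follows one mechanical rule: graph construction moves from prog.add_* to mm->add_*, where mm is the module* cached from prog.get_main_module(). Since mutation now goes through a pointer member instead of through this, the parse helpers can be const-qualified in the same sweep. The pattern in miniature, condensed from the make_contiguous hunk:

struct onnx_parser
{
    program prog = program();
    module* mm = prog.get_main_module(); // every instruction lands in the main module

    instruction_ref make_contiguous(instruction_ref ins) const
    {
        if(ins->get_shape().standard())
            return ins;
        // const member function, yet writing through mm is fine:
        // it mutates the module, not the parser object itself
        return mm->add_instruction(make_op("contiguous"), ins);
    }
};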
@@ -4,7 +4,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring::apply(program& p) const
void memory_coloring::apply(module& p) const
{
if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
{
@@ -67,7 +67,7 @@ using interval_ptr = live_interval*;
struct memory_coloring_impl
{
memory_coloring_impl(program* p, std::string alloc_op, bool p_verify)
memory_coloring_impl(module* p, std::string alloc_op, bool p_verify)
: p_program(p), allocation_op(std::move(alloc_op)), enable_verify(p_verify)
{
instr2_live.clear();
@@ -145,7 +145,7 @@ struct memory_coloring_impl
return (i1->offset > i2->offset);
}
};
program* p_program;
module* p_program;
std::unordered_map<const instruction*, interval_ptr> instr2_live;
// universe of live intervals.
std::vector<live_interval> live_intervals;
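With the constructor retargeted to module*, the pass would construct its impl as below; the member names allocation_op and verify and the run() entry point are assumptions for illustration, only the constructor signature comes from the hunk above:

void memory_coloring::apply(module& p) const
{
    if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
    {
        // the impl keeps a non-owning module* for the liveness analysis
        memory_coloring_impl impl(&p, allocation_op, verify); // names assumed
        impl.run();                                           // assumed entry point
    }
}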
@@ -15,20 +15,20 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
void run_passes(module& modl, const std::vector<pass>& passes, tracer trace)
{
for(const auto& p : passes)
{
trace("Pass: ", p.name());
p.apply(prog);
trace(prog);
p.apply(modl);
trace(modl);
#ifndef NDEBUG
trace("Validate ...");
auto invalid = prog.validate();
if(invalid != prog.end())
auto invalid = modl.validate();
if(invalid != modl.end())
{
auto index = std::distance(prog.begin(), invalid);
auto index = std::distance(modl.begin(), invalid);
MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " +
std::to_string(index) + ": " + invalid->name());
}
@@ -20,7 +20,7 @@ bool skip_propogate(instruction_ref ins)
return false;
}
void propagate_constant::apply(program& p) const
void propagate_constant::apply(module& p) const
{
for(auto i : iterator_for(p))
{
@@ -256,8 +256,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::target>(m, "target");
py::class_<migraphx::module_wrap>(m, "module")
.def("print", [](const migraphx::module_wrap& mm) { std::cout << *mm.prog << std::endl; });
py::class_<migraphx::module>(m, "module").def("print", [](const migraphx::module& mm) {
std::cout << mm << std::endl;
});
py::class_<migraphx::program>(m, "program")
.def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); })
@@ -277,12 +278,12 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("fast_math") = true)
.def("get_main_module",
[](migraphx::program& p) {
auto mm = p.get_main_module();
return migraphx::module_wrap{mm};
auto* mm = p.get_main_module();
return migraphx::module{*mm};
})
.def("run",
[](migraphx::program& p, py::dict params) {
migraphx::program::parameter_map pm;
migraphx::parameter_map pm;
for(auto x : params)
{
std::string key = x.first.cast<std::string>();
@@ -389,7 +390,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
&migraphx::quantize_int8,
py::arg("prog"),
py::arg("t"),
py::arg("calibration") = std::vector<migraphx::program::parameter_map>{},
py::arg("calibration") = std::vector<migraphx::parameter_map>{},
py::arg("ins_names") = std::vector<std::string>{"dot", "convolution"});
#ifdef HAVE_GPU
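After this hunk the Python module class wraps a real migraphx::module value instead of the module_wrap shim, and get_main_module returns a copy of the main module rather than a wrapper around the program. The equivalent C++ usage, as a sketch (driver code assumed, not part of the diff):

migraphx::program p;            // assume p was parsed or built elsewhere
auto* mm = p.get_main_module(); // non-owning pointer, as in the binding above
std::cout << *mm << std::endl;  // a module streams like a program while the alias holds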
@@ -27,7 +27,7 @@ inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
instruction_ref insert_quant_ins(program& prog,
instruction_ref insert_quant_ins(module& modl,
instruction_ref& ins,
shape::type_t type,
std::unordered_map<instruction_ref, instruction_ref>& map_ins,
@@ -59,11 +59,11 @@ instruction_ref insert_quant_ins(program& prog,
if(scaled_ins->get_shape().type() != shape::float_type)
{
float_ins =
prog.insert_instruction(insert_loc, op::convert{shape::float_type}, scaled_ins);
modl.insert_instruction(insert_loc, op::convert{shape::float_type}, scaled_ins);
}
std::vector<float> vec_scale(scaled_ins->get_shape().elements(), scale);
auto l_scale = prog.add_literal(literal(float_ins->get_shape(), vec_scale));
scaled_ins = prog.insert_instruction(insert_loc, op::mul{}, l_scale, float_ins);
auto l_scale = modl.add_literal(literal(float_ins->get_shape(), vec_scale));
scaled_ins = modl.insert_instruction(insert_loc, op::mul{}, l_scale, float_ins);
}
auto shifted_ins = scaled_ins;
@@ -72,27 +72,27 @@ instruction_ref insert_quant_ins(program& prog,
auto float_ins = shifted_ins;
if(shifted_ins->get_shape().type() != shape::float_type)
{
float_ins = prog.insert_instruction(
float_ins = modl.insert_instruction(
insert_loc, op::convert{shape::float_type}, shifted_ins);
}
std::vector<float> vec_shift(shifted_ins->get_shape().elements(), shift);
auto l_shift = prog.add_literal(literal(float_ins->get_shape(), vec_shift));
shifted_ins = prog.insert_instruction(insert_loc, op::add{}, l_shift, float_ins);
auto l_shift = modl.add_literal(literal(float_ins->get_shape(), vec_shift));
shifted_ins = modl.insert_instruction(insert_loc, op::add{}, l_shift, float_ins);
}
auto rounded_ins = prog.insert_instruction(insert_loc, op::round{}, shifted_ins);
auto rounded_ins = modl.insert_instruction(insert_loc, op::round{}, shifted_ins);
auto rounded_lens = rounded_ins->get_shape().lens();
auto max_clip = prog.add_literal(127.0f);
auto min_clip = prog.add_literal(-128.0f);
max_clip = prog.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, max_clip);
min_clip = prog.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, min_clip);
auto max_clip = modl.add_literal(127.0f);
auto min_clip = modl.add_literal(-128.0f);
max_clip = modl.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, max_clip);
min_clip = modl.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, min_clip);
auto clipped_ins =
prog.insert_instruction(insert_loc, op::clip{}, rounded_ins, min_clip, max_clip);
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, clipped_ins);
modl.insert_instruction(insert_loc, op::clip{}, rounded_ins, min_clip, max_clip);
quant_ins = modl.insert_instruction(insert_loc, op::convert{type}, clipped_ins);
}
else
{
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, ins);
quant_ins = modl.insert_instruction(insert_loc, op::convert{type}, ins);
}
map_ins[ins] = quant_ins;
@@ -107,8 +107,9 @@ instruction_ref insert_quant_ins(program& prog,
// truncate of the input to get the fp16.
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
{
auto* mm = prog.get_main_module();
std::unordered_map<instruction_ref, instruction_ref> map_fp16;
for(auto ins : iterator_for(prog))
for(auto ins : iterator_for(*mm))
{
if(ins->name() == "@return")
break;
@@ -139,7 +140,7 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
}
else
{
input_fp16 = insert_quant_ins(prog, input, shape::half_type, map_fp16);
input_fp16 = insert_quant_ins(*mm, input, shape::half_type, map_fp16);
}
converted_inputs.push_back(input_fp16);
}
@@ -162,18 +163,18 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
// check the dead code case to avoid assert
bool output_empty = ins->outputs().empty();
auto ins_orig_type =
prog.insert_instruction(std::next(ins), op::convert{orig_type}, ins);
mm->insert_instruction(std::next(ins), op::convert{orig_type}, ins);
if(!output_empty)
{
prog.replace_instruction(ins, ins_orig_type);
mm->replace_instruction(ins, ins_orig_type);
}
}
prog.replace_instruction(ins, op, converted_inputs);
mm->replace_instruction(ins, op, converted_inputs);
}
}
static void ins_quantize_int8(program& prog,
static void ins_quantize_int8(module& modl,
instruction_ref ins,
std::vector<instruction_ref>& converted_inputs,
const std::vector<std::pair<float, float>>& ins_quant_params)
@@ -195,14 +196,14 @@ static void ins_quantize_int8(program& prog,
int32_t quant_beta = static_cast<int32_t>(std::round(new_beta));
if(shape::int32_type == orig_type)
{
prog.replace_instruction(
modl.replace_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
}
else
{
auto quant_dot = prog.insert_instruction(
auto quant_dot = modl.insert_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
prog.replace_instruction(ins, op::convert{orig_type}, quant_dot);
modl.replace_instruction(ins, op::convert{orig_type}, quant_dot);
}
}
// either alpha or beta cannot be quantized because of too big
......@@ -213,51 +214,51 @@ static void ins_quantize_int8(program& prog,
{
converted_inputs.pop_back();
}
auto q_dot = prog.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs);
auto f_dot = prog.insert_instruction(ins, op::convert{shape::float_type}, q_dot);
auto q_dot = modl.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs);
auto f_dot = modl.insert_instruction(ins, op::convert{shape::float_type}, q_dot);
auto c_shape = q_dot->get_shape();
std::vector<float> vec_alpha(c_shape.elements(), new_alpha);
auto l_alpha =
prog.add_literal(literal({shape::float_type, c_shape.lens()}, vec_alpha));
modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_alpha));
if(inputs.size() == 3 and dot_op.beta != 0.0f)
{
auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
auto alpha_ab = modl.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
std::vector<float> vec_beta(c_shape.elements(), dot_op.beta);
auto l_beta =
prog.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta));
modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta));
instruction_ref beta_c{};
if(orig_type != shape::float_type)
{
auto fp32_c =
prog.insert_instruction(ins, op::convert{shape::float_type}, inputs.back());
beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, fp32_c);
modl.insert_instruction(ins, op::convert{shape::float_type}, inputs.back());
beta_c = modl.insert_instruction(ins, op::mul{}, l_beta, fp32_c);
}
else
{
beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back());
beta_c = modl.insert_instruction(ins, op::mul{}, l_beta, inputs.back());
}
if(orig_type == shape::float_type)
{
prog.replace_instruction(ins, op::add{}, alpha_ab, beta_c);
modl.replace_instruction(ins, op::add{}, alpha_ab, beta_c);
}
else
{
auto f_res = prog.insert_instruction(ins, op::add{}, alpha_ab, beta_c);
prog.replace_instruction(ins, op::convert{orig_type}, f_res);
auto f_res = modl.insert_instruction(ins, op::add{}, alpha_ab, beta_c);
modl.replace_instruction(ins, op::convert{orig_type}, f_res);
}
}
else
{
if(orig_type == shape::float_type)
{
prog.replace_instruction(ins, op::mul{}, l_alpha, f_dot);
modl.replace_instruction(ins, op::mul{}, l_alpha, f_dot);
}
else
{
auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
prog.replace_instruction(ins, op::convert{orig_type}, alpha_ab);
auto alpha_ab = modl.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
modl.replace_instruction(ins, op::convert{orig_type}, alpha_ab);
}
}
}
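When the coefficients cannot be folded in, the pass evaluates `alpha * (A·B) + beta * C` in float and converts back to the original type, mirroring the `convert`/`mul`/`add` instructions inserted above. The dataflow, reduced to scalars:

```cpp
#include <cstdint>
#include <iostream>

int main()
{
    int32_t q_dot = 1000;              // int32 output of quant_dot{1, 0}
    float alpha = 0.5f, beta = 2.0f;   // original gemm coefficients
    float c = 3.0f;                    // the optional third input C

    float f_dot    = static_cast<float>(q_dot); // op::convert to float
    float alpha_ab = alpha * f_dot;             // op::mul with l_alpha
    float beta_c   = beta * c;                  // op::mul with l_beta
    float result   = alpha_ab + beta_c;         // op::add, then convert
    std::cout << result << '\n';                // back to orig_type if needed
}
```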
......@@ -274,7 +275,7 @@ static void ins_quantize_int8(program& prog,
auto group = conv_op.group;
auto adjust_factor = 1.0f / (ins_quant_params[0].first * ins_quant_params[1].first);
auto quant_conv = prog.insert_instruction(
auto quant_conv = modl.insert_instruction(
ins,
op::quant_convolution{padding, stride, dilation, padding_mode, group},
converted_inputs);
......@@ -282,25 +283,25 @@ static void ins_quantize_int8(program& prog,
std::vector<float> vec_factor(quant_conv->get_shape().elements(), adjust_factor);
if(quant_conv->get_shape().type() == orig_type and adjust_factor >= threshold)
{
auto l_factor = prog.add_literal(
auto l_factor = modl.add_literal(
literal(quant_conv->get_shape(), vec_factor.begin(), vec_factor.end()));
prog.replace_instruction(ins, op::mul{}, quant_conv, l_factor);
modl.replace_instruction(ins, op::mul{}, quant_conv, l_factor);
}
// convert quant_conv output to float type, multiply the factor and
// convert back to the original type
else
{
auto float_conv =
prog.insert_instruction(ins, op::convert{shape::float_type}, quant_conv);
auto l_factor = prog.add_literal(literal(float_conv->get_shape(), vec_factor));
modl.insert_instruction(ins, op::convert{shape::float_type}, quant_conv);
auto l_factor = modl.add_literal(literal(float_conv->get_shape(), vec_factor));
if(orig_type == shape::float_type)
{
prog.replace_instruction(ins, op::mul{}, l_factor, float_conv);
modl.replace_instruction(ins, op::mul{}, l_factor, float_conv);
}
else
{
auto adjusted_conv = prog.insert_instruction(ins, op::mul{}, l_factor, float_conv);
prog.replace_instruction(ins, op::convert{orig_type}, adjusted_conv);
auto adjusted_conv = modl.insert_instruction(ins, op::mul{}, l_factor, float_conv);
modl.replace_instruction(ins, op::convert{orig_type}, adjusted_conv);
}
}
}
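For convolution the two input scales are divided out in one step: `adjust_factor = 1.0f / (scale_x * scale_w)` maps the int32 accumulator back to the original range, either directly or via the float round trip in the else-branch above. A scalar sketch (the scale values are made up):

```cpp
#include <cstdint>
#include <iostream>

int main()
{
    float scale_x = 64.0f, scale_w = 32.0f; // assumed per-input scales
    int32_t q_out = 204800;                 // int32 quant_convolution output

    float adjust_factor = 1.0f / (scale_x * scale_w);
    float dequantized   = static_cast<float>(q_out) * adjust_factor;
    std::cout << dequantized << '\n';       // 100: back in the float range
}
```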
......@@ -338,10 +339,11 @@ void quantize_int8_impl(program& prog,
MIGRAPHX_THROW("QUANTIZE_INT8: only DOT and CONVOLUTION operations are supported");
}
auto* mm = prog.get_main_module();
std::size_t quant_param_index = 0;
std::unordered_map<instruction_ref, instruction_ref> map_quant_ins;
std::unordered_map<instruction_ref, std::size_t> map_ins_index;
for(auto ins : iterator_for(prog))
for(auto ins : iterator_for(*mm))
{
if(ins->name() == "@return")
break;
......@@ -398,7 +400,7 @@ void quantize_int8_impl(program& prog,
else
{
quant_input = insert_quant_ins(
prog, input, quant_type, map_quant_ins, param.first, param.second);
*mm, input, quant_type, map_quant_ins, param.first, param.second);
}
converted_inputs.push_back(quant_input);
}
......@@ -414,7 +416,7 @@ void quantize_int8_impl(program& prog,
continue;
}
ins_quantize_int8(prog, ins, converted_inputs, ins_quant_params);
ins_quantize_int8(*mm, ins, converted_inputs, ins_quant_params);
}
if(quant_param_index != quant_params.size())
......@@ -425,7 +427,7 @@ void quantize_int8_impl(program& prog,
void quantize_int8(program& prog,
const target& t,
const std::vector<program::parameter_map>& calibration,
const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names)
{
// insert capture operator
......@@ -439,7 +441,7 @@ void quantize_int8(program& prog,
// quantization scale and shift
for(auto&& arg : calibration)
{
program::parameter_map m;
parameter_map m;
for(auto&& x : cap_prog.get_parameter_shapes())
{
if(arg.count(x.first) > 0)
......@@ -464,7 +466,7 @@ std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func)
{
auto* mm = prog.get_main_module();
size_t num_quant_params = 0;
// the int8 quantization only supports dot and convolution
std::set<std::string> op_names = {"dot", "convolution"};
......@@ -476,7 +478,7 @@ std::size_t capture_arguments(program& prog,
}
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(prog))
for(auto ins : iterator_for(*mm))
{
if(not contains(ins_names, ins->name()))
{
......@@ -494,7 +496,7 @@ std::size_t capture_arguments(program& prog,
}
else
{
new_ins = prog.insert_instruction(
new_ins = mm->insert_instruction(
std::next(input), op::capture{num_quant_params++, func}, input);
ins_map[input] = new_ins;
}
......
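`capture_arguments` wraps each dot/convolution input in `op::capture{index, func}`, so the callback observes real tensors during the calibration runs driven by `quantize_int8`. How such a callback turns observed values into a quantization scale is not shown in this diff; one conventional symmetric-int8 scheme, offered as an assumption, looks like this:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <map>
#include <vector>

int main()
{
    // running max of |x| per capture index, across calibration batches
    std::map<std::size_t, float> max_abs;

    auto callback = [&](std::size_t index, const std::vector<float>& values) {
        float m = 0.0f;
        for(float v : values)
            m = std::max(m, std::fabs(v));
        max_abs[index] = std::max(max_abs[index], m);
    };

    callback(0, {0.1f, -2.5f, 1.0f}); // one calibration batch
    callback(0, {3.0f, -0.5f});       // a later batch widens the range

    // symmetric int8: map the observed extreme onto +/-127
    float scale = 127.0f / max_abs[0];
    std::cout << "scale = " << scale << '\n';
}
```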
......@@ -22,7 +22,7 @@ struct find_dot_add
match::name("dot")(match::nargs(2)).bind("dot"))));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto dot_ins = r.instructions["dot"];
......@@ -36,7 +36,7 @@ struct find_dot_add
};
} // namespace
void remap::apply(program& p) const { match::find_matches(p, find_dot_add{}); }
void remap::apply(module& p) const { match::find_matches(p, find_dot_add{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
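`find_dot_add` matches an `add` fed by a two-argument `dot`; judging by the matcher, `remap` folds the addend into the gemm so that `add(dot(a, b), c)` becomes a single three-input `dot`. A toy-IR sketch of that rewrite (hypothetical types, not MIGraphX's instruction graph):

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// a toy expression node, just to show the shape of the rewrite
struct expr
{
    std::string op;
    std::vector<std::shared_ptr<expr>> args;
};

using expr_ptr = std::shared_ptr<expr>;
expr_ptr make(std::string op, std::vector<expr_ptr> args = {})
{
    return std::make_shared<expr>(expr{std::move(op), std::move(args)});
}

// add(dot(a, b), c)  ->  dot(a, b, c)
expr_ptr fuse_dot_add(const expr_ptr& e)
{
    if(e->op == "add" and e->args.size() == 2 and e->args[0]->op == "dot" and
       e->args[0]->args.size() == 2)
        return make("dot", {e->args[0]->args[0], e->args[0]->args[1], e->args[1]});
    return e;
}

int main()
{
    auto g = make("add", {make("dot", {make("a"), make("b")}), make("c")});
    std::cout << fuse_dot_add(g)->op << '\n'; // dot
}
```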
......@@ -12,7 +12,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_batchnorm::apply(program& p) const
void rewrite_batchnorm::apply(module& p) const
{
for(auto ins : iterator_for(p))
{
......
......@@ -10,7 +10,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_pooling::apply(program& prog) const
void rewrite_pooling::apply(module& prog) const
{
for(auto ins : iterator_for(prog))
{
......
......@@ -28,7 +28,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_rnn::apply(program& prog) const
void rewrite_rnn::apply(module& prog) const
{
for(auto ins : iterator_for(prog))
{
......@@ -47,13 +47,13 @@ void rewrite_rnn::apply(program& prog) const
}
}
void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
{
assert(ins->name() == "rnn");
    // there could be 3 to 6 inputs; the parse_rnn function appends
    // undefined operators to pad the list to 6 arguments when parsing
    // an onnx file. Alternatively, the user may pass any number of arguments
// when writing their program.
// when writing their module.
auto args = ins->inputs();
shape seq_shape = args[0]->get_shape();
......@@ -210,7 +210,7 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
}
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
program& prog,
module& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
operation& actv_func) const
......@@ -336,7 +336,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins)
}
}
void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
{
assert(ins->name() == "gru");
const auto actv_funcs = gru_actv_funcs(ins);
......@@ -502,7 +502,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
}
std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
program& prog,
module& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
......@@ -685,7 +685,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
}
// for lstm operators
void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
{
assert(ins->name() == "lstm");
auto args = ins->inputs();
......@@ -927,7 +927,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
}
std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
program& prog,
module& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
const operation& actv_func1,
......@@ -1158,7 +1158,7 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
}
}
bool rewrite_rnn::is_variable_seq_lens(const program& prog, instruction_ref seq_lens) const
bool rewrite_rnn::is_variable_seq_lens(const module& prog, instruction_ref seq_lens) const
{
bool is_var_lens = false;
if(seq_lens != prog.end())
......@@ -1188,7 +1188,7 @@ bool rewrite_rnn::is_variable_seq_lens(const program& prog, instruction_ref seq_
}
std::size_t
rewrite_rnn::get_seq_len(const program& prog, instruction_ref input, instruction_ref seq_lens) const
rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const
{
bool is_var_lens = is_variable_seq_lens(prog, seq_lens);
auto input_shape = input->get_shape();
......@@ -1204,7 +1204,7 @@ rewrite_rnn::get_seq_len(const program& prog, instruction_ref input, instruction
return length;
}
instruction_ref rewrite_rnn::replace_last_hs_output(program& prog,
instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref last_hs_output,
......@@ -1243,7 +1243,7 @@ instruction_ref rewrite_rnn::replace_last_hs_output(program& prog,
return result_ins;
}
void rewrite_rnn::replace_last_cell_output(program& prog,
void rewrite_rnn::replace_last_cell_output(module& prog,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref cell_outputs,
......@@ -1281,7 +1281,7 @@ void rewrite_rnn::replace_last_cell_output(program& prog,
}
}
instruction_ref rewrite_rnn::pad_hidden_states(program& prog,
instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
instruction_ref seq,
instruction_ref seq_lens,
instruction_ref hs) const
......
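Much of the machinery above (`is_variable_seq_lens`, `replace_last_hs_output`, `replace_last_cell_output`, `pad_hidden_states`) exists because with variable sequence lengths the last meaningful hidden state of batch element `b` sits at step `seq_lens[b] - 1`, not at the padded end. The core indexing, in plain C++ as a sketch of the idea rather than the pass itself:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // hs[t][b] = hidden state (a scalar here) at time t for batch element b
    std::vector<std::vector<float>> hs = {
        {1.0f, 10.0f}, {2.0f, 20.0f}, {3.0f, 0.0f}}; // b=1 is padding at t=2
    std::vector<std::size_t> seq_lens = {3, 2};

    std::vector<float> last(seq_lens.size());
    for(std::size_t b = 0; b < seq_lens.size(); ++b)
        last[b] = hs[seq_lens[b] - 1][b]; // last valid step per sequence

    std::cout << last[0] << ", " << last[1] << '\n'; // 3, 20
}
```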
......@@ -103,7 +103,7 @@ struct stream_info
}
};
std::size_t assign_streams(program& p, std::size_t n)
std::size_t assign_streams(module& p, std::size_t n)
{
assert(n > 0);
partition critical;
......@@ -182,7 +182,7 @@ struct stream_info
}
};
void sort(program& p, std::size_t) const
void sort(module& p, std::size_t)
{
std::set<weight_ins, compare_weight_ins> children;
std::unordered_map<instruction_ref, std::size_t> visited;
......@@ -335,7 +335,7 @@ struct stream_info
}
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>>
find_concurrent_instructions(program& p) const
find_concurrent_instructions(module& p) const
{
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
......@@ -378,7 +378,7 @@ struct stream_info
}
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>
get_conflicts(program& p)
get_conflicts(module& p)
{
using conflict_table_type =
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>;
......@@ -464,7 +464,7 @@ struct stream_info
}
};
void schedule::apply(program& p) const
void schedule::apply(module& p) const
{
if(not enable)
return;
......
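`assign_streams` spreads instructions across `n` concurrent streams, with a weighted "critical" partition pinned to one stream. The real pass walks the dataflow graph; a greedy least-loaded assignment conveys the balancing idea (an illustration, not the actual algorithm):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<std::size_t> weights = {8, 1, 5, 3, 2}; // per-instruction cost
    std::size_t n = 2;                                  // concurrency

    std::vector<std::size_t> load(n, 0), stream(weights.size());
    for(std::size_t i = 0; i < weights.size(); ++i)
    {
        // put each instruction on the currently least-loaded stream
        auto s = std::min_element(load.begin(), load.end()) - load.begin();
        stream[i] = s;
        load[s] += weights[i];
    }

    for(std::size_t i = 0; i < stream.size(); ++i)
        std::cout << "ins " << i << " -> stream " << stream[i] << '\n';
}
```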
......@@ -50,7 +50,7 @@ struct find_mul_conv
match::name("broadcast").bind("a")));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto conv_ins = r.instructions["conv"];
......@@ -86,7 +86,7 @@ struct find_mul_slice_conv
match::name("broadcast")(match::is_constant()).bind("a")));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto slice_ins = r.instructions["slice"];
......@@ -169,7 +169,7 @@ struct find_mul_add
match::is_constant().bind("a")));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
......@@ -191,7 +191,7 @@ struct find_add_lit_broadcast
match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b")));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
......@@ -211,7 +211,7 @@ struct find_double_add_lit_broadcast
match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
......@@ -249,7 +249,7 @@ struct find_inner_broadcast
match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
......@@ -294,7 +294,7 @@ struct find_concat_op
return op.name() == "broadcast" or op.attributes().contains("pointwise");
}
void apply(program& p, const match::matcher_result& r) const
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
auto axis = any_cast<op::concat>(ins->get_operator()).axis;
......@@ -425,7 +425,7 @@ struct find_splits
return groups;
}
void apply(program& p, const match::matcher_result& r) const
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
......@@ -520,7 +520,7 @@ struct find_split_concat
match::name("slice")(match::all_of[match::outputs()](match::name("concat")))));
}
void apply(program& p, const match::matcher_result& r) const
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
......@@ -618,7 +618,7 @@ struct find_add_convs
input.strides()[3] * n}};
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto a_conv = r.instructions["a"];
......@@ -689,7 +689,7 @@ struct find_conv_dot_horiz_fusion
{
auto matcher() const { return horiz_conv_dot(); }
void apply(program& p, const match::matcher_result& r) const
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
......@@ -762,7 +762,7 @@ struct find_div_const
return match::name("div")(match::arg(1)(match::is_constant().bind("c")));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto c_ins = r.instructions["c"];
......@@ -782,7 +782,7 @@ struct find_sub_const
return match::name("sub")(match::arg(1)(match::is_constant().bind("c")));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto c_ins = r.instructions["c"];
......@@ -803,7 +803,7 @@ struct find_rsqrt
match::name("sqrt")(match::used_once(), match::args(match::any().bind("x")))));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
......@@ -828,7 +828,7 @@ struct find_split_reshape
.bind("reshape");
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto slc = r.instructions["slice"];
auto rsp = r.instructions["reshape"];
......@@ -904,7 +904,7 @@ struct find_split_transpose
.bind("trans");
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto slc = r.instructions["slice"];
auto trans = r.instructions["trans"];
......@@ -949,7 +949,7 @@ struct find_split_transpose
}
};
void simplify_algebra::apply(program& p) const
void simplify_algebra::apply(module& p) const
{
// Run simplifications multiple times
for(int i = 0; i < 8; i++)
......
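`simplify_algebra::apply` reruns the whole matcher set a fixed eight times because one rewrite can expose another. An equivalent driver iterates until no matcher fires, keeping the same cap as a safety bound; a toy sketch with a string standing in for the IR:

```cpp
#include <iostream>
#include <string>

// stand-in for one full sweep over the matchers; returns true if it rewrote
bool run_matchers(std::string& ir)
{
    auto pos = ir.find("x*1");
    if(pos == std::string::npos)
        return false;
    ir.replace(pos, 3, "x"); // simplify x*1 -> x
    return true;
}

int main()
{
    std::string ir = "x*1*1";
    int i = 0;
    // rerun until nothing fires, with the pass's cap of 8 as a safety bound
    while(i++ < 8 and run_matchers(ir))
        ;
    std::cout << ir << '\n'; // prints: x
}
```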
......@@ -66,7 +66,7 @@ struct find_reshaper
match::any_of[match::outputs()](match::name(reshaper_names())));
}
void apply(program& p, const match::matcher_result& mr) const
void apply(module& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
std::vector<instruction_ref> reshapes{ins};
......@@ -113,7 +113,7 @@ struct find_nop_reshapes
return match::name(reshapes)(match::same_shape(match::arg(0)));
}
void apply(program& p, const match::matcher_result& mr) const
void apply(module& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
p.replace_instruction(ins, ins->inputs().front());
......@@ -128,7 +128,7 @@ struct find_transpose
match::skip_output(match::name("contiguous"))(match::name("transpose"))));
}
void apply(program& p, const match::matcher_result& mr) const
void apply(module& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto x = ins;
......@@ -201,7 +201,7 @@ struct find_nested_slice
return result;
}
void apply(program& p, const match::matcher_result& mr) const
void apply(module& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto slice = ins->inputs().front();
......@@ -230,7 +230,7 @@ struct find_concat_transpose
return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape()));
}
void apply(program& p, const match::matcher_result& mr) const
void apply(module& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto trans_inputs = ins->inputs();
......@@ -279,7 +279,7 @@ struct find_nested_concat
return op.axis;
}
void apply(program& p, const match::matcher_result& mr) const
void apply(module& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto axis = get_axis(ins);
......@@ -298,7 +298,7 @@ struct find_nested_concat
}
};
void simplify_reshapes::apply(program& p) const
void simplify_reshapes::apply(module& p) const
{
for(int i = 0; i < 2; i++)
{
......
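`find_transpose` collapses a chain of transposes (optionally separated by `contiguous`) into one, since two permutations compose into a single permutation `q[i] = p1[p2[i]]` under the convention that output axis `i` reads input axis `perm[i]` (assumed here to match `op::transpose`). A quick check of the composition rule:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// applying transpose p1 and then p2 equals the single transpose q
std::vector<std::size_t> compose(const std::vector<std::size_t>& p1,
                                 const std::vector<std::size_t>& p2)
{
    std::vector<std::size_t> q(p1.size());
    for(std::size_t i = 0; i < q.size(); ++i)
        q[i] = p1[p2[i]];
    return q;
}

int main()
{
    std::vector<std::size_t> p1 = {1, 2, 0}; // first transpose
    std::vector<std::size_t> p2 = {2, 0, 1}; // second transpose (its inverse)
    for(auto a : compose(p1, p2))
        std::cout << a << ' '; // 0 1 2 : the pair cancels to the identity
    std::cout << '\n';
}
```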