Unverified Commit 2466dd6f authored by Shucai Xiao, committed by GitHub

Refactor program to module (#684)



* code backup

* clang format

* change corresponding tool files

* clang format
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent de10423f
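The whole change follows one pattern: introduce `module` as an alias of the existing `program` type, then retarget pass and helper signatures at the alias, so every call site keeps compiling while the two concepts are split apart later. A minimal sketch of the pattern, with a hypothetical pass for illustration (not the actual MIGraphX headers):

    #include <string>

    namespace migraphx {
    struct program;          // forward declaration, as in the headers below
    using module = program;  // transitional alias: module and program are
                             // the same type for now

    // hypothetical pass, showing the retargeted signature
    struct example_pass
    {
        std::string name() const { return "example_pass"; }
        void apply(module& m) const;  // was: void apply(program& p) const;
    };
    } // namespace migraphx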
......@@ -10,6 +10,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Schedule instructions for concurrent execution
......@@ -19,7 +20,7 @@ struct schedule
schedule_model model{};
bool enable = true;
std::string name() const { return "schedule"; }
void apply(program& p) const;
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -16,6 +16,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct operation;
#ifdef DOXYGEN
......@@ -26,11 +27,11 @@ struct schedule_model
/// Get the number of concurrent instructions allowed
std::size_t concurrency() const;
/// Schedule a concurrent instruction
void sched(program& p, instruction_ref ins, std::size_t n) const;
void sched(module& p, instruction_ref ins, std::size_t n) const;
// Insert necessary waits before an instruction
void wait(program& p, instruction_ref ins, std::size_t wait_id) const;
void wait(module& p, instruction_ref ins, std::size_t wait_id) const;
// Insert necessary records after an instruction
void record(program& p, instruction_ref ins, std::size_t wait_id) const;
void record(module& p, instruction_ref ins, std::size_t wait_id) const;
/// Compute weights for an operation
std::size_t weight(const operation& op) const;
};
......@@ -43,9 +44,9 @@ struct schedule_model
* struct schedule_model
* {
* std::size_t concurrency() const;
* void sched(program& p,instruction_ref ins,std::size_t n) const;
* void wait(program& p,instruction_ref ins,std::size_t wait_id) const;
* void record(program& p,instruction_ref ins,std::size_t wait_id) const;
* void sched(module& p,instruction_ref ins,std::size_t n) const;
* void wait(module& p,instruction_ref ins,std::size_t wait_id) const;
* void record(module& p,instruction_ref ins,std::size_t wait_id) const;
* std::size_t weight(const operation& op) const;
* };
*
......@@ -120,19 +121,19 @@ struct schedule_model
return (*this).private_detail_te_get_handle().concurrency();
}
void sched(program& p, instruction_ref ins, std::size_t n) const
void sched(module& p, instruction_ref ins, std::size_t n) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().sched(p, ins, n);
}
void wait(program& p, instruction_ref ins, std::size_t wait_id) const
void wait(module& p, instruction_ref ins, std::size_t wait_id) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().wait(p, ins, wait_id);
}
void record(program& p, instruction_ref ins, std::size_t wait_id) const
void record(module& p, instruction_ref ins, std::size_t wait_id) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().record(p, ins, wait_id);
......@@ -158,11 +159,11 @@ struct schedule_model
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::size_t concurrency() const = 0;
virtual void sched(program& p, instruction_ref ins, std::size_t n) const = 0;
virtual void wait(program& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual void record(program& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual std::size_t weight(const operation& op) const = 0;
virtual std::size_t concurrency() const = 0;
virtual void sched(module& p, instruction_ref ins, std::size_t n) const = 0;
virtual void wait(module& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual void record(module& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual std::size_t weight(const operation& op) const = 0;
};
template <typename PrivateDetailTypeErasedT>
......@@ -195,19 +196,19 @@ struct schedule_model
std::size_t concurrency() const override { return private_detail_te_value.concurrency(); }
void sched(program& p, instruction_ref ins, std::size_t n) const override
void sched(module& p, instruction_ref ins, std::size_t n) const override
{
private_detail_te_value.sched(p, ins, n);
}
void wait(program& p, instruction_ref ins, std::size_t wait_id) const override
void wait(module& p, instruction_ref ins, std::size_t wait_id) const override
{
private_detail_te_value.wait(p, ins, wait_id);
}
void record(program& p, instruction_ref ins, std::size_t wait_id) const override
void record(module& p, instruction_ref ins, std::size_t wait_id) const override
{
private_detail_te_value.record(p, ins, wait_id);
......
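schedule_model above is a type-erased interface: any type providing the five documented member functions can be stored in it and dispatched through the private_detail_te wrapper shown in the hunks above. As a hedged sketch, a trivial model that permits no concurrency (the include path for the instruction_ref typedef is an assumption; the real models live in the target code):

    #include <cstddef>
    #include <migraphx/instruction_ref.hpp>  // assumed location of instruction_ref

    namespace migraphx {
    struct program;
    using module = program;  // as in this commit
    struct operation;

    // Illustrative only: a single-stream model, so the schedule pass needs
    // no waits or records and treats every operation as equal weight.
    struct serial_model
    {
        std::size_t concurrency() const { return 1; }
        void sched(module&, instruction_ref, std::size_t) const {}
        void wait(module&, instruction_ref, std::size_t) const {}
        void record(module&, instruction_ref, std::size_t) const {}
        std::size_t weight(const operation&) const { return 1; }
    };
    } // namespace migraphx

Assigning serial_model{} to a schedule_model (for instance when building the schedule pass) would then route through the type-erased handle seen above.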
......@@ -8,6 +8,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Simplify many algebraic instructions to more efficient versions.
......@@ -15,7 +16,7 @@ struct program;
struct simplify_algebra
{
std::string name() const { return "simplify_algebra"; }
void apply(program& p) const;
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Eliminate redundant reshapes.
......@@ -16,7 +17,7 @@ struct program;
struct simplify_reshapes
{
std::string name() const { return "simplify_reshapes"; }
void apply(program& p) const;
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -61,7 +61,7 @@ int main(int argc, char const* argv[])
{
// GPU target
prog.compile(migraphx::gpu::target{});
migraphx::program::parameter_map m;
migraphx::parameter_map m;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}};
for(auto&& x : prog.get_parameter_shapes())
{
......
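The tool files now use the top-level `migraphx::parameter_map` (a map from parameter name to argument) instead of the old nested `program::parameter_map` name. The surrounding loop fills it roughly like this, sketched from the tool code above (eval's exact return type varies across versions):

    migraphx::parameter_map m;
    for(auto&& x : prog.get_parameter_shapes())
    {
        // generate_argument fabricates data matching each parameter's shape;
        // to_gpu stages it on the device for the GPU-compiled program
        m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second));
    }
    auto result = prog.eval(m);  // run with the populated map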
......@@ -124,7 +124,7 @@ int main(int argc, char const* argv[])
auto s = migraphx::shape{migraphx::shape::float_type, {1, 1, 28, 28}};
std::cout << s << std::endl;
auto* ptr = input.data();
migraphx::program::parameter_map m;
migraphx::parameter_map m;
m["output"] =
migraphx::gpu::to_gpu(migraphx::generate_argument(prog.get_parameter_shape("output")));
for(int i = 0; i < 20; i++)
......
This diff is collapsed.
......@@ -4,7 +4,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring::apply(program& p) const
void memory_coloring::apply(module& p) const
{
if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
{
......
......@@ -67,7 +67,7 @@ using interval_ptr = live_interval*;
struct memory_coloring_impl
{
memory_coloring_impl(program* p, std::string alloc_op, bool p_verify)
memory_coloring_impl(module* p, std::string alloc_op, bool p_verify)
: p_program(p), allocation_op(std::move(alloc_op)), enable_verify(p_verify)
{
instr2_live.clear();
......@@ -145,7 +145,7 @@ struct memory_coloring_impl
return (i1->offset > i2->offset);
}
};
program* p_program;
module* p_program;
std::unordered_map<const instruction*, interval_ptr> instr2_live;
// universe of live intervals.
std::vector<live_interval> live_intervals;
......
......@@ -15,20 +15,20 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
void run_passes(module& modl, const std::vector<pass>& passes, tracer trace)
{
for(const auto& p : passes)
{
trace("Pass: ", p.name());
p.apply(prog);
trace(prog);
p.apply(modl);
trace(modl);
#ifndef NDEBUG
trace("Validate ...");
auto invalid = prog.validate();
if(invalid != prog.end())
auto invalid = modl.validate();
if(invalid != modl.end())
{
auto index = std::distance(prog.begin(), invalid);
auto index = std::distance(modl.begin(), invalid);
MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " +
std::to_string(index) + ": " + invalid->name());
}
......
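run_passes itself is unchanged apart from the rename: it applies each pass in order, traces it, and in debug builds validates the module after every pass, throwing at the first invalid instruction. A hedged sketch of driving it directly (compile() normally does this internally; a default-constructed tracer is assumed to be silent):

    std::vector<migraphx::pass> passes = {migraphx::simplify_algebra{},
                                          migraphx::simplify_reshapes{},
                                          migraphx::dead_code_elimination{}};

    // With module aliased to program, the program itself can be passed
    migraphx::run_passes(prog, passes, migraphx::tracer{});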
......@@ -20,7 +20,7 @@ bool skip_propogate(instruction_ref ins)
return false;
}
void propagate_constant::apply(program& p) const
void propagate_constant::apply(module& p) const
{
for(auto i : iterator_for(p))
{
......
......@@ -256,8 +256,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::target>(m, "target");
py::class_<migraphx::module_wrap>(m, "module")
.def("print", [](const migraphx::module_wrap& mm) { std::cout << *mm.prog << std::endl; });
py::class_<migraphx::module>(m, "module").def("print", [](const migraphx::module& mm) {
std::cout << mm << std::endl;
});
py::class_<migraphx::program>(m, "program")
.def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); })
......@@ -277,12 +278,12 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("fast_math") = true)
.def("get_main_module",
[](migraphx::program& p) {
auto mm = p.get_main_module();
return migraphx::module_wrap{mm};
auto* mm = p.get_main_module();
return migraphx::module{*mm};
})
.def("run",
[](migraphx::program& p, py::dict params) {
migraphx::program::parameter_map pm;
migraphx::parameter_map pm;
for(auto x : params)
{
std::string key = x.first.cast<std::string>();
......@@ -389,7 +390,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
&migraphx::quantize_int8,
py::arg("prog"),
py::arg("t"),
py::arg("calibration") = std::vector<migraphx::program::parameter_map>{},
py::arg("calibration") = std::vector<migraphx::parameter_map>{},
py::arg("ins_names") = std::vector<std::string>{"dot", "convolution"});
#ifdef HAVE_GPU
......
......@@ -27,7 +27,7 @@ inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
instruction_ref insert_quant_ins(program& prog,
instruction_ref insert_quant_ins(module& modl,
instruction_ref& ins,
shape::type_t type,
std::unordered_map<instruction_ref, instruction_ref>& map_ins,
......@@ -59,11 +59,11 @@ instruction_ref insert_quant_ins(program& prog,
if(scaled_ins->get_shape().type() != shape::float_type)
{
float_ins =
prog.insert_instruction(insert_loc, op::convert{shape::float_type}, scaled_ins);
modl.insert_instruction(insert_loc, op::convert{shape::float_type}, scaled_ins);
}
std::vector<float> vec_scale(scaled_ins->get_shape().elements(), scale);
auto l_scale = prog.add_literal(literal(float_ins->get_shape(), vec_scale));
scaled_ins = prog.insert_instruction(insert_loc, op::mul{}, l_scale, float_ins);
auto l_scale = modl.add_literal(literal(float_ins->get_shape(), vec_scale));
scaled_ins = modl.insert_instruction(insert_loc, op::mul{}, l_scale, float_ins);
}
auto shifted_ins = scaled_ins;
......@@ -72,27 +72,27 @@ instruction_ref insert_quant_ins(program& prog,
auto float_ins = shifted_ins;
if(shifted_ins->get_shape().type() != shape::float_type)
{
float_ins = prog.insert_instruction(
float_ins = modl.insert_instruction(
insert_loc, op::convert{shape::float_type}, shifted_ins);
}
std::vector<float> vec_shift(shifted_ins->get_shape().elements(), shift);
auto l_shift = prog.add_literal(literal(float_ins->get_shape(), vec_shift));
shifted_ins = prog.insert_instruction(insert_loc, op::add{}, l_shift, float_ins);
auto l_shift = modl.add_literal(literal(float_ins->get_shape(), vec_shift));
shifted_ins = modl.insert_instruction(insert_loc, op::add{}, l_shift, float_ins);
}
auto rounded_ins = prog.insert_instruction(insert_loc, op::round{}, shifted_ins);
auto rounded_ins = modl.insert_instruction(insert_loc, op::round{}, shifted_ins);
auto rounded_lens = rounded_ins->get_shape().lens();
auto max_clip = prog.add_literal(127.0f);
auto min_clip = prog.add_literal(-128.0f);
max_clip = prog.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, max_clip);
min_clip = prog.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, min_clip);
auto max_clip = modl.add_literal(127.0f);
auto min_clip = modl.add_literal(-128.0f);
max_clip = modl.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, max_clip);
min_clip = modl.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, min_clip);
auto clipped_ins =
prog.insert_instruction(insert_loc, op::clip{}, rounded_ins, min_clip, max_clip);
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, clipped_ins);
modl.insert_instruction(insert_loc, op::clip{}, rounded_ins, min_clip, max_clip);
quant_ins = modl.insert_instruction(insert_loc, op::convert{type}, clipped_ins);
}
else
{
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, ins);
quant_ins = modl.insert_instruction(insert_loc, op::convert{type}, ins);
}
map_ins[ins] = quant_ins;
......@@ -107,8 +107,9 @@ instruction_ref insert_quant_ins(program& prog,
// truncate the input to get the fp16.
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
{
auto* mm = prog.get_main_module();
std::unordered_map<instruction_ref, instruction_ref> map_fp16;
for(auto ins : iterator_for(prog))
for(auto ins : iterator_for(*mm))
{
if(ins->name() == "@return")
break;
......@@ -139,7 +140,7 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
}
else
{
input_fp16 = insert_quant_ins(prog, input, shape::half_type, map_fp16);
input_fp16 = insert_quant_ins(*mm, input, shape::half_type, map_fp16);
}
converted_inputs.push_back(input_fp16);
}
......@@ -162,18 +163,18 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
// check the dead code case to avoid assert
bool output_empty = ins->outputs().empty();
auto ins_orig_type =
prog.insert_instruction(std::next(ins), op::convert{orig_type}, ins);
mm->insert_instruction(std::next(ins), op::convert{orig_type}, ins);
if(!output_empty)
{
prog.replace_instruction(ins, ins_orig_type);
mm->replace_instruction(ins, ins_orig_type);
}
}
prog.replace_instruction(ins, op, converted_inputs);
mm->replace_instruction(ins, op, converted_inputs);
}
}
static void ins_quantize_int8(program& prog,
static void ins_quantize_int8(module& modl,
instruction_ref ins,
std::vector<instruction_ref>& converted_inputs,
const std::vector<std::pair<float, float>>& ins_quant_params)
......@@ -195,14 +196,14 @@ static void ins_quantize_int8(program& prog,
int32_t quant_beta = static_cast<int32_t>(std::round(new_beta));
if(shape::int32_type == orig_type)
{
prog.replace_instruction(
modl.replace_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
}
else
{
auto quant_dot = prog.insert_instruction(
auto quant_dot = modl.insert_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
prog.replace_instruction(ins, op::convert{orig_type}, quant_dot);
modl.replace_instruction(ins, op::convert{orig_type}, quant_dot);
}
}
// either alpha or beta cannot be quantized because of too big
......@@ -213,51 +214,51 @@ static void ins_quantize_int8(program& prog,
{
converted_inputs.pop_back();
}
auto q_dot = prog.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs);
auto f_dot = prog.insert_instruction(ins, op::convert{shape::float_type}, q_dot);
auto q_dot = modl.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs);
auto f_dot = modl.insert_instruction(ins, op::convert{shape::float_type}, q_dot);
auto c_shape = q_dot->get_shape();
std::vector<float> vec_alpha(c_shape.elements(), new_alpha);
auto l_alpha =
prog.add_literal(literal({shape::float_type, c_shape.lens()}, vec_alpha));
modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_alpha));
if(inputs.size() == 3 and dot_op.beta != 0.0f)
{
auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
auto alpha_ab = modl.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
std::vector<float> vec_beta(c_shape.elements(), dot_op.beta);
auto l_beta =
prog.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta));
modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta));
instruction_ref beta_c{};
if(orig_type != shape::float_type)
{
auto fp32_c =
prog.insert_instruction(ins, op::convert{shape::float_type}, inputs.back());
beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, fp32_c);
modl.insert_instruction(ins, op::convert{shape::float_type}, inputs.back());
beta_c = modl.insert_instruction(ins, op::mul{}, l_beta, fp32_c);
}
else
{
beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back());
beta_c = modl.insert_instruction(ins, op::mul{}, l_beta, inputs.back());
}
if(orig_type == shape::float_type)
{
prog.replace_instruction(ins, op::add{}, alpha_ab, beta_c);
modl.replace_instruction(ins, op::add{}, alpha_ab, beta_c);
}
else
{
auto f_res = prog.insert_instruction(ins, op::add{}, alpha_ab, beta_c);
prog.replace_instruction(ins, op::convert{orig_type}, f_res);
auto f_res = modl.insert_instruction(ins, op::add{}, alpha_ab, beta_c);
modl.replace_instruction(ins, op::convert{orig_type}, f_res);
}
}
else
{
if(orig_type == shape::float_type)
{
prog.replace_instruction(ins, op::mul{}, l_alpha, f_dot);
modl.replace_instruction(ins, op::mul{}, l_alpha, f_dot);
}
else
{
auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
prog.replace_instruction(ins, op::convert{orig_type}, alpha_ab);
auto alpha_ab = modl.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
modl.replace_instruction(ins, op::convert{orig_type}, alpha_ab);
}
}
}
......@@ -274,7 +275,7 @@ static void ins_quantize_int8(program& prog,
auto group = conv_op.group;
auto adjust_factor = 1.0f / (ins_quant_params[0].first * ins_quant_params[1].first);
auto quant_conv = prog.insert_instruction(
auto quant_conv = modl.insert_instruction(
ins,
op::quant_convolution{padding, stride, dilation, padding_mode, group},
converted_inputs);
......@@ -282,25 +283,25 @@ static void ins_quantize_int8(program& prog,
std::vector<float> vec_factor(quant_conv->get_shape().elements(), adjust_factor);
if(quant_conv->get_shape().type() == orig_type and adjust_factor >= threshold)
{
auto l_factor = prog.add_literal(
auto l_factor = modl.add_literal(
literal(quant_conv->get_shape(), vec_factor.begin(), vec_factor.end()));
prog.replace_instruction(ins, op::mul{}, quant_conv, l_factor);
modl.replace_instruction(ins, op::mul{}, quant_conv, l_factor);
}
// convert quant_conv output to float type, multiply the factor and
// convert back to original type
else
{
auto float_conv =
prog.insert_instruction(ins, op::convert{shape::float_type}, quant_conv);
auto l_factor = prog.add_literal(literal(float_conv->get_shape(), vec_factor));
modl.insert_instruction(ins, op::convert{shape::float_type}, quant_conv);
auto l_factor = modl.add_literal(literal(float_conv->get_shape(), vec_factor));
if(orig_type == shape::float_type)
{
prog.replace_instruction(ins, op::mul{}, l_factor, float_conv);
modl.replace_instruction(ins, op::mul{}, l_factor, float_conv);
}
else
{
auto adjusted_conv = prog.insert_instruction(ins, op::mul{}, l_factor, float_conv);
prog.replace_instruction(ins, op::convert{orig_type}, adjusted_conv);
auto adjusted_conv = modl.insert_instruction(ins, op::mul{}, l_factor, float_conv);
modl.replace_instruction(ins, op::convert{orig_type}, adjusted_conv);
}
}
}
......@@ -338,10 +339,11 @@ void quantize_int8_impl(program& prog,
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
}
auto* mm = prog.get_main_module();
std::size_t quant_param_index = 0;
std::unordered_map<instruction_ref, instruction_ref> map_quant_ins;
std::unordered_map<instruction_ref, std::size_t> map_ins_index;
for(auto ins : iterator_for(prog))
for(auto ins : iterator_for(*mm))
{
if(ins->name() == "@return")
break;
......@@ -398,7 +400,7 @@ void quantize_int8_impl(program& prog,
else
{
quant_input = insert_quant_ins(
prog, input, quant_type, map_quant_ins, param.first, param.second);
*mm, input, quant_type, map_quant_ins, param.first, param.second);
}
converted_inputs.push_back(quant_input);
}
......@@ -414,7 +416,7 @@ void quantize_int8_impl(program& prog,
continue;
}
ins_quantize_int8(prog, ins, converted_inputs, ins_quant_params);
ins_quantize_int8(*mm, ins, converted_inputs, ins_quant_params);
}
if(quant_param_index != quant_params.size())
......@@ -425,7 +427,7 @@ void quantize_int8_impl(program& prog,
void quantize_int8(program& prog,
const target& t,
const std::vector<program::parameter_map>& calibration,
const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names)
{
// insert capture operator
......@@ -439,7 +441,7 @@ void quantize_int8(program& prog,
// quantization scale and shift
for(auto&& arg : calibration)
{
program::parameter_map m;
parameter_map m;
for(auto&& x : cap_prog.get_parameter_shapes())
{
if(arg.count(x.first) > 0)
......@@ -464,7 +466,7 @@ std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func)
{
auto* mm = prog.get_main_module();
size_t num_quant_params = 0;
// the int8 quantization only supports dot and convolution
std::set<std::string> op_names = {"dot", "convolution"};
......@@ -476,7 +478,7 @@ std::size_t capture_arguments(program& prog,
}
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(prog))
for(auto ins : iterator_for(*mm))
{
if(not contains(ins_names, ins->name()))
{
......@@ -494,7 +496,7 @@ std::size_t capture_arguments(program& prog,
}
else
{
new_ins = prog.insert_instruction(
new_ins = mm->insert_instruction(
std::next(input), op::capture{num_quant_params++, func}, input);
ins_map[input] = new_ins;
}
......
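Both quantization entry points keep their program-level signatures and fetch the main module themselves before rewriting instructions. A hedged usage sketch (the target and calibration inputs here are illustrative):

    // fp16: rewrite the listed operators to half precision
    migraphx::quantize_fp16(prog, {"dot", "convolution"});

    // int8: a target plus calibration runs are needed to derive the
    // per-instruction scale and shift parameters
    std::vector<migraphx::parameter_map> calibration = {m};  // e.g. a map built as above
    migraphx::quantize_int8(
        prog, migraphx::gpu::target{}, calibration, {"dot", "convolution"});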
......@@ -22,7 +22,7 @@ struct find_dot_add
match::name("dot")(match::nargs(2)).bind("dot"))));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto dot_ins = r.instructions["dot"];
......@@ -36,7 +36,7 @@ struct find_dot_add
};
} // namespace
void remap::apply(program& p) const { match::find_matches(p, find_dot_add{}); }
void remap::apply(module& p) const { match::find_matches(p, find_dot_add{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -12,7 +12,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_batchnorm::apply(program& p) const
void rewrite_batchnorm::apply(module& p) const
{
for(auto ins : iterator_for(p))
{
......
......@@ -10,7 +10,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_pooling::apply(program& prog) const
void rewrite_pooling::apply(module& prog) const
{
for(auto ins : iterator_for(prog))
{
......
......@@ -28,7 +28,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_rnn::apply(program& prog) const
void rewrite_rnn::apply(module& prog) const
{
for(auto ins : iterator_for(prog))
{
......@@ -47,13 +47,13 @@ void rewrite_rnn::apply(program& prog) const
}
}
void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
{
assert(ins->name() == "rnn");
// could be 3 to 6 inputs, but the parse_rnn function will
// append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is that the user can pass any number of arguments
// when writing their program.
// when writing their module.
auto args = ins->inputs();
shape seq_shape = args[0]->get_shape();
......@@ -210,7 +210,7 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
}
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
program& prog,
module& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
operation& actv_func) const
......@@ -336,7 +336,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins)
}
}
void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
{
assert(ins->name() == "gru");
const auto actv_funcs = gru_actv_funcs(ins);
......@@ -502,7 +502,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
}
std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
program& prog,
module& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
......@@ -685,7 +685,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
}
// for lstm operators
void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
{
assert(ins->name() == "lstm");
auto args = ins->inputs();
......@@ -927,7 +927,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
}
std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
program& prog,
module& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
const operation& actv_func1,
......@@ -1158,7 +1158,7 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
}
}
bool rewrite_rnn::is_variable_seq_lens(const program& prog, instruction_ref seq_lens) const
bool rewrite_rnn::is_variable_seq_lens(const module& prog, instruction_ref seq_lens) const
{
bool is_var_lens = false;
if(seq_lens != prog.end())
......@@ -1188,7 +1188,7 @@ bool rewrite_rnn::is_variable_seq_lens(const program& prog, instruction_ref seq_
}
std::size_t
rewrite_rnn::get_seq_len(const program& prog, instruction_ref input, instruction_ref seq_lens) const
rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const
{
bool is_var_lens = is_variable_seq_lens(prog, seq_lens);
auto input_shape = input->get_shape();
......@@ -1204,7 +1204,7 @@ rewrite_rnn::get_seq_len(const program& prog, instruction_ref input, instruction
return length;
}
instruction_ref rewrite_rnn::replace_last_hs_output(program& prog,
instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref last_hs_output,
......@@ -1243,7 +1243,7 @@ instruction_ref rewrite_rnn::replace_last_hs_output(program& prog,
return result_ins;
}
void rewrite_rnn::replace_last_cell_output(program& prog,
void rewrite_rnn::replace_last_cell_output(module& prog,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref cell_outputs,
......@@ -1281,7 +1281,7 @@ void rewrite_rnn::replace_last_cell_output(program& prog,
}
}
instruction_ref rewrite_rnn::pad_hidden_states(program& prog,
instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
instruction_ref seq,
instruction_ref seq_lens,
instruction_ref hs) const
......
......@@ -103,7 +103,7 @@ struct stream_info
}
};
std::size_t assign_streams(program& p, std::size_t n)
std::size_t assign_streams(module& p, std::size_t n)
{
assert(n > 0);
partition critical;
......@@ -182,7 +182,7 @@ struct stream_info
}
};
void sort(program& p, std::size_t) const
void sort(module& p, std::size_t)
{
std::set<weight_ins, compare_weight_ins> children;
std::unordered_map<instruction_ref, std::size_t> visited;
......@@ -335,7 +335,7 @@ struct stream_info
}
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>>
find_concurrent_instructions(program& p) const
find_concurrent_instructions(module& p) const
{
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
......@@ -378,7 +378,7 @@ struct stream_info
}
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>
get_conflicts(program& p)
get_conflicts(module& p)
{
using conflict_table_type =
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>;
......@@ -464,7 +464,7 @@ struct stream_info
}
};
void schedule::apply(program& p) const
void schedule::apply(module& p) const
{
if(not enable)
return;
......
......@@ -50,7 +50,7 @@ struct find_mul_conv
match::name("broadcast").bind("a")));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto conv_ins = r.instructions["conv"];
......@@ -86,7 +86,7 @@ struct find_mul_slice_conv
match::name("broadcast")(match::is_constant()).bind("a")));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto slice_ins = r.instructions["slice"];
......@@ -169,7 +169,7 @@ struct find_mul_add
match::is_constant().bind("a")));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
......@@ -191,7 +191,7 @@ struct find_add_lit_broadcast
match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b")));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
......@@ -211,7 +211,7 @@ struct find_double_add_lit_broadcast
match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
......@@ -249,7 +249,7 @@ struct find_inner_broadcast
match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
......@@ -294,7 +294,7 @@ struct find_concat_op
return op.name() == "broadcast" or op.attributes().contains("pointwise");
}
void apply(program& p, const match::matcher_result& r) const
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
auto axis = any_cast<op::concat>(ins->get_operator()).axis;
......@@ -425,7 +425,7 @@ struct find_splits
return groups;
}
void apply(program& p, const match::matcher_result& r) const
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
......@@ -520,7 +520,7 @@ struct find_split_concat
match::name("slice")(match::all_of[match::outputs()](match::name("concat")))));
}
void apply(program& p, const match::matcher_result& r) const
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
......@@ -618,7 +618,7 @@ struct find_add_convs
input.strides()[3] * n}};
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto a_conv = r.instructions["a"];
......@@ -689,7 +689,7 @@ struct find_conv_dot_horiz_fusion
{
auto matcher() const { return horiz_conv_dot(); }
void apply(program& p, const match::matcher_result& r) const
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
......@@ -762,7 +762,7 @@ struct find_div_const
return match::name("div")(match::arg(1)(match::is_constant().bind("c")));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto c_ins = r.instructions["c"];
......@@ -782,7 +782,7 @@ struct find_sub_const
return match::name("sub")(match::arg(1)(match::is_constant().bind("c")));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto c_ins = r.instructions["c"];
......@@ -803,7 +803,7 @@ struct find_rsqrt
match::name("sqrt")(match::used_once(), match::args(match::any().bind("x")))));
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
......@@ -828,7 +828,7 @@ struct find_split_reshape
.bind("reshape");
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto slc = r.instructions["slice"];
auto rsp = r.instructions["reshape"];
......@@ -904,7 +904,7 @@ struct find_split_transpose
.bind("trans");
}
void apply(program& p, match::matcher_result r) const
void apply(module& p, match::matcher_result r) const
{
auto slc = r.instructions["slice"];
auto trans = r.instructions["trans"];
......@@ -949,7 +949,7 @@ struct find_split_transpose
}
};
void simplify_algebra::apply(program& p) const
void simplify_algebra::apply(module& p) const
{
// Run simplifications multiple times
for(int i = 0; i < 8; i++)
......
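Every simplification in this file is a small matcher struct: matcher() describes the instruction pattern, apply(module&, matcher_result) rewrites the hit, and the pass's own apply feeds them all to match::find_matches. A minimal hedged sketch of that shape after the rename, using a hypothetical simplification that is not part of this file:

    namespace {
    // hypothetical: fold neg(neg(x)) back to x
    struct find_double_neg
    {
        auto matcher() const
        {
            return match::name("neg")(
                match::args(match::name("neg")(match::args(match::any().bind("x")))));
        }

        void apply(module& p, match::matcher_result r) const
        {
            // drop both negations by rewiring the outer result to x
            p.replace_instruction(r.result, r.instructions["x"]);
        }
    };
    } // namespace

    // a pass would then run it as:
    // void my_pass::apply(module& p) const { match::find_matches(p, find_double_neg{}); }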
......@@ -66,7 +66,7 @@ struct find_reshaper
match::any_of[match::outputs()](match::name(reshaper_names())));
}
void apply(program& p, const match::matcher_result& mr) const
void apply(module& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
std::vector<instruction_ref> reshapes{ins};
......@@ -113,7 +113,7 @@ struct find_nop_reshapes
return match::name(reshapes)(match::same_shape(match::arg(0)));
}
void apply(program& p, const match::matcher_result& mr) const
void apply(module& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
p.replace_instruction(ins, ins->inputs().front());
......@@ -128,7 +128,7 @@ struct find_transpose
match::skip_output(match::name("contiguous"))(match::name("transpose"))));
}
void apply(program& p, const match::matcher_result& mr) const
void apply(module& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto x = ins;
......@@ -201,7 +201,7 @@ struct find_nested_slice
return result;
}
void apply(program& p, const match::matcher_result& mr) const
void apply(module& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto slice = ins->inputs().front();
......@@ -230,7 +230,7 @@ struct find_concat_transpose
return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape()));
}
void apply(program& p, const match::matcher_result& mr) const
void apply(module& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto trans_inputs = ins->inputs();
......@@ -279,7 +279,7 @@ struct find_nested_concat
return op.axis;
}
void apply(program& p, const match::matcher_result& mr) const
void apply(module& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto axis = get_axis(ins);
......@@ -298,7 +298,7 @@ struct find_nested_concat
}
};
void simplify_reshapes::apply(program& p) const
void simplify_reshapes::apply(module& p) const
{
for(int i = 0; i < 2; i++)
{
......