Commit 4566709b authored by Paul

Renamed to use name function directly
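Every change in this commit replaces ins->op.name() with ins->name(), so instruction is assumed to expose a name() accessor that forwards to its op member (op is still used directly in the remaining any_cast<...>(ins->op) calls). A minimal sketch of what that accessor presumably looks like; the exact declaration in instruction.hpp is an assumption, not shown in this diff:

    // Hypothetical forwarding accessor on instruction (sketch only).
    // The real declaration may differ; the call sites below only rely on
    // name() returning the same string as op.name().
    struct instruction
    {
        operation op; // still accessed directly where the concrete operator is needed
        std::string name() const { return op.name(); } // delegate to the operator's name
        // ... result, arguments, output, lit, etc.
    };

Call sites that only need the operator's name can then drop the ->op indirection, while any_cast on ins->op remains wherever the concrete operator type is required.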

parent 25f560c3
@@ -17,7 +17,7 @@ void dead_code_elimination::apply(program& p) const
             continue;
         const auto i = std::prev(ins);
         // Skip instruction with empty shape as output unless its a builtin
-        if(i->result.elements() == 0 and not(i->op.name().front() == '@'))
+        if(i->result.elements() == 0 and not(i->name().front() == '@'))
             continue;
         // Skip the last instruction
         if(i == last)
...
@@ -14,7 +14,7 @@ void eliminate_allocation::apply(program& p) const
     std::vector<std::pair<instruction_ref, std::size_t>> allocs;
     for(auto ins : iterator_for(p))
     {
-        if(ins->op.name() != allocation_op)
+        if(ins->name() != allocation_op)
             continue;
         allocs.emplace_back(ins, n);
         std::size_t size = ins->get_shape().bytes();
...
@@ -32,7 +32,7 @@ void eliminate_contiguous::apply(program& p) const
     {
         // TODO: Pass in names for the operator in the constructor instead
         // of using ends_with
-        if(ends_with(arg->op.name(), "contiguous"))
+        if(ends_with(arg->name(), "contiguous"))
         {
             auto new_args = args;
             auto prev = arg->arguments.front();
...
@@ -10,17 +10,17 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
 {
     for(auto ins : iterator_for(p))
     {
-        if(ins->op.name() != "batch_norm_inference")
+        if(ins->name() != "batch_norm_inference")
             continue;
         if(not std::all_of(ins->arguments.begin() + 1, ins->arguments.end(), [](auto arg) {
-               return arg->op.name() == "@literal";
+               return arg->name() == "@literal";
            }))
             continue;
         auto conv_ins = ins->arguments[0];
-        if(conv_ins->op.name() != "convolution")
+        if(conv_ins->name() != "convolution")
             continue;
-        if(conv_ins->arguments[1]->op.name() != "@literal")
+        if(conv_ins->arguments[1]->name() != "@literal")
             continue;
         // Get scale, bias, mean, variance from instruction_ref
         const auto& gamma = ins->arguments[1]->get_literal();
...
@@ -32,7 +32,7 @@ struct instruction
         result = r;
         for(auto&& ins : output)
         {
-            assert(ins->op.name().front() != '@');
+            assert(ins->name().front() != '@');
             ins->recompute_shape();
         }
     }
...
@@ -57,20 +57,20 @@ void verify_instructions(const migraph::program& prog, double tolerance = 80)
 {
     for(auto&& ins : prog)
     {
-        if(ins.op.name().front() == '@')
+        if(ins.name().front() == '@')
             continue;
-        if(ins.op.name() == "broadcast")
+        if(ins.name() == "broadcast")
             continue;
-        if(ins.op.name() == "transpose")
+        if(ins.name() == "transpose")
             continue;
-        if(ins.op.name() == "reshape")
+        if(ins.name() == "reshape")
             continue;
         auto create_program = [&] {
             migraph::program p;
             std::vector<migraph::instruction_ref> inputs;
             for(auto&& arg : ins.arguments)
             {
-                if(arg->op.name() == "@literal")
+                if(arg->name() == "@literal")
                     inputs.push_back(p.add_literal(arg->lit));
                 else
                     inputs.push_back(
@@ -81,13 +81,13 @@ void verify_instructions(const migraph::program& prog, double tolerance = 80)
         };
         try
         {
-            std::cout << "Verify: " << ins.op.name() << std::endl;
+            std::cout << "Verify: " << ins.name() << std::endl;
             std::cout << create_program() << std::endl;
-            verify_program(ins.op.name(), create_program, tolerance);
+            verify_program(ins.name(), create_program, tolerance);
         }
         catch(...)
         {
-            std::cout << "Instruction " << ins.op.name() << " threw an exception." << std::endl;
+            std::cout << "Instruction " << ins.name() << " threw an exception." << std::endl;
             throw;
         }
     }
...
@@ -31,7 +31,7 @@ static void print_program(std::ostream& os, const program& p, F annonate)
     for(auto ins : iterator_for(p))
     {
         std::string var_name = "@" + std::to_string(count);
-        if(ins->op.name() == "@param")
+        if(ins->name() == "@param")
         {
             var_name = any_cast<builtin::param>(ins->op).parameter;
         }
@@ -40,7 +40,7 @@ static void print_program(std::ostream& os, const program& p, F annonate)
         os << ins->op;
-        if(ins->op.name() == "@literal")
+        if(ins->name() == "@literal")
         {
             if(ins->lit.get_shape().elements() > 10)
                 os << "{ ... }";
@@ -188,7 +188,7 @@ shape program::get_parameter_shape(std::string name) const
 {
     auto ins = std::find_if(
         impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
-            if(x.op.name() == "@param")
+            if(x.name() == "@param")
             {
                 return any_cast<builtin::param>(x.op).parameter == name;
             }
@@ -208,7 +208,7 @@ std::unordered_map<std::string, shape> program::get_parameter_shapes() const
     std::unordered_map<std::string, shape> result;
     for(auto&& ins : impl->instructions)
     {
-        if(ins.op.name() == "@param")
+        if(ins.name() == "@param")
         {
             auto&& name = any_cast<builtin::param>(ins.op).parameter;
             result[name] = ins.result;
@@ -258,7 +258,7 @@ void program::compile(const target& t, tracer trace)
     {
         auto index = std::distance(impl->instructions.begin(), invalid);
         MIGRAPH_THROW(p.name() + " pass produces invalid program at instruction " +
-                      std::to_string(index) + ": " + invalid->op.name());
+                      std::to_string(index) + ": " + invalid->name());
     }
     trace();
 #endif
@@ -284,17 +284,17 @@ argument generic_eval(const program& p,
     values.reserve(16);
     for(auto ins : iterator_for(p))
     {
-        if(ins->op.name() == "@literal")
+        if(ins->name() == "@literal")
         {
             results.emplace(ins, trace(ins, [&] { return ins->lit.get_argument(); }));
         }
-        else if(ins->op.name() == "@param")
+        else if(ins->name() == "@param")
         {
             results.emplace(ins, trace(ins, [&] {
                 return params.at(any_cast<builtin::param>(ins->op).parameter);
             }));
         }
-        else if(ins->op.name() == "@outline")
+        else if(ins->name() == "@outline")
         {
             results.emplace(ins, trace(ins, [&] { return argument{ins->result, nullptr}; }));
         }
@@ -385,7 +385,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
     for(auto&& p : ins_vec)
     {
         double avg = common_average(p.second);
-        op_times[p.first->op.name()] += avg;
+        op_times[p.first->name()] += avg;
         total_instruction_time += avg;
     }
     double calculate_overhead_time = total_time - total_instruction_time;
...
@@ -25,15 +25,15 @@ void simplify_reshapes::apply(program& p) const
 {
     for(auto ins : iterator_for(p))
     {
-        if(not is_reshaper(ins->op.name()))
+        if(not is_reshaper(ins->name()))
            continue;
         if(ins->output.size() != 1)
            continue;
-        if(is_reshaper(ins->output.front()->op.name()))
+        if(is_reshaper(ins->output.front()->name()))
            continue;
         // Gather reshapes
         std::vector<instruction_ref> reshapes{ins};
-        while(is_reshaper(reshapes.back()->op.name()))
+        while(is_reshaper(reshapes.back()->name()))
         {
             assert(!reshapes.back()->arguments.empty());
             assert(p.has_instruction(reshapes.back()->arguments.front()));
...
@@ -579,17 +579,17 @@ struct cpu_apply
         init();
         for(auto it : iterator_for(*prog))
         {
-            if(it->op.name() == "activation")
+            if(it->name() == "activation")
             {
                 apply_activation(it);
             }
-            else if(it->op.name() == "pooling")
+            else if(it->name() == "pooling")
             {
                 apply_pooling(it);
             }
-            else if(apply_map.count(it->op.name()) > 0)
+            else if(apply_map.count(it->name()) > 0)
             {
-                apply_map.at(it->op.name())(it);
+                apply_map.at(it->name())(it);
             }
         }
     }
...
@@ -18,7 +18,7 @@ void eliminate_workspace::apply(program& p) const
     {
         if(ins->output.size() != 1)
             continue;
-        if(ins->op.name() != "hip::allocate")
+        if(ins->name() != "hip::allocate")
             continue;
         auto&& a = any_cast<hip_allocate>(ins->op);
         if(a.tag == "workspace")
...
@@ -26,10 +26,10 @@ void fuse_ops::apply(program& p) const
 {
     for(auto ins : iterator_for(p))
     {
-        if(ins->op.name() != "gpu::relu")
+        if(ins->name() != "gpu::relu")
             continue;
         auto add_ins = ins->arguments.front();
-        if(add_ins->op.name() != "gpu::add")
+        if(add_ins->name() != "gpu::add")
             continue;
         auto args = add_ins->arguments;
         // Use the allocation from the relu operator
...
@@ -322,31 +322,31 @@ struct miopen_apply
         for(auto it = prog->begin(); it != prog->end(); it++)
         {
             auto s = it->get_shape();
-            if(it->op.name() == "convolution")
+            if(it->name() == "convolution")
             {
                 check_shape(s, apply_convolution(it));
             }
-            else if(it->op.name() == "activation")
+            else if(it->name() == "activation")
             {
                 check_shape(s, apply_activation(it));
             }
-            else if(it->op.name() == "pooling")
+            else if(it->name() == "pooling")
             {
                 check_shape(s, apply_pooling(it));
             }
-            else if(it->op.name() == "add")
+            else if(it->name() == "add")
             {
                 check_shape(s, apply_add(it));
             }
-            else if(it->op.name() == "gemm")
+            else if(it->name() == "gemm")
             {
                 check_shape(s, apply_gemm(it));
             }
-            else if(it->op.name() == "contiguous")
+            else if(it->name() == "contiguous")
             {
                 check_shape(s, apply_contiguous(it));
             }
-            else if(it->op.name() == "batch_norm_inference")
+            else if(it->name() == "batch_norm_inference")
             {
                 check_shape(s, apply_batch_norm_inference(it));
             }
...
@@ -28,7 +28,7 @@ void write_literals::apply(program& p) const
     assert(ctx != nullptr);
     for(auto ins : iterator_for(p))
     {
-        if(ins->op.name() == "@literal")
+        if(ins->name() == "@literal")
         {
             argument a = to_gpu(ins->lit.get_argument());
             std::size_t n = ctx->literals.size();
...