Commit d2af0a8c authored by Paul

Rename direct uses of the arguments member to the inputs() function

parent 872ff12a
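
The diff below swaps every direct read of the public arguments member for an inputs() accessor. Only the call sites appear in this commit, so the accessor itself is an assumption; a minimal sketch of what it presumably looks like:

    // Hypothetical sketch only: the accessor wraps the existing member so
    // call sites stop depending on the field itself.
    struct instruction
    {
        const std::vector<instruction_ref>& inputs() const { return arguments; }
        // ... rest of the struct unchanged ...
        std::vector<instruction_ref> arguments;
    };

Returning a const reference would keep the copy-then-mutate idiom below cheap to reason about: writing auto args = ins->inputs(); makes the copy explicit at the call site.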
@@ -26,7 +26,7 @@ void dead_code_elimination::apply(program& p) const
         assert(p.has_instruction(leaf));
         if(leaf->output.empty())
         {
-            auto args = leaf->arguments;
+            auto args = leaf->inputs();
             leaf->clear_arguments();
             p.move_instruction(leaf, p.end());
             for(auto arg : args)
@@ -27,15 +27,15 @@ void eliminate_contiguous::apply(program& p) const
     for(auto ins : iterator_for(p))
     {
         // Make a copy so we can modify it while we iterate
-        auto args = ins->arguments;
-        for(auto arg : ins->arguments)
+        auto args = ins->inputs();
+        for(auto arg : ins->inputs())
         {
             // TODO: Pass in names for the operator in the constructor instead
             // of using ends_with
             if(ends_with(arg->name(), "contiguous"))
             {
                 auto new_args = args;
-                auto prev = arg->arguments.front();
+                auto prev = arg->inputs().front();
                 replace(new_args, arg, prev);
                 if(try_compute_shape(ins->op, new_args))
                 {
@@ -12,26 +12,26 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
     {
         if(ins->name() != "batch_norm_inference")
             continue;
-        if(not std::all_of(ins->arguments.begin() + 1, ins->arguments.end(), [](auto arg) {
+        if(not std::all_of(ins->inputs().begin() + 1, ins->inputs().end(), [](auto arg) {
                return arg->name() == "@literal";
            }))
             continue;
-        auto conv_ins = ins->arguments[0];
+        auto conv_ins = ins->inputs()[0];
         if(conv_ins->name() != "convolution")
             continue;
-        if(conv_ins->arguments[1]->name() != "@literal")
+        if(conv_ins->inputs()[1]->name() != "@literal")
             continue;
         // Get scale, bias, mean, variance from instruction_ref
-        const auto& gamma = ins->arguments[1]->get_literal();
-        const auto& bias = ins->arguments[2]->get_literal();
-        const auto& mean = ins->arguments[3]->get_literal();
-        const auto& variance = ins->arguments[4]->get_literal();
+        const auto& gamma = ins->inputs()[1]->get_literal();
+        const auto& bias = ins->inputs()[2]->get_literal();
+        const auto& mean = ins->inputs()[3]->get_literal();
+        const auto& variance = ins->inputs()[4]->get_literal();
         // Get epsilon
         auto bn_op = any_cast<batch_norm_inference>(ins->op);
         auto epsilon = bn_op.epsilon;
         // Get convolution weights
-        const auto& weights = conv_ins->arguments[1]->get_literal();
+        const auto& weights = conv_ins->inputs()[1]->get_literal();
         // Get convolution op
         auto conv_op = conv_ins->op;
         auto weights_lens = weights.get_shape().lens();
@@ -58,7 +58,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
         // Replace convolution instruction with updated weights
         auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
         auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()});
-        auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->arguments[0], l_weights});
+        auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights});
         auto b = p.insert_instruction(ins, broadcast{1}, c, l_bias);
         p.replace_instruction(ins, add{}, {c, b});
     }
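
For context, the folding this pass implements is the standard conv+batchnorm identity: batch_norm_inference computes y = gamma * (x - mean) / sqrt(variance + epsilon) + bias per channel, so the normalization can be absorbed into the convolution by scaling each output channel's weights and emitting a per-channel bias. A sketch of that arithmetic with illustrative names (weights_for_channel is a hypothetical helper, not part of this codebase):

    // Illustrative only: fold batchnorm parameters into conv weights/bias.
    for(std::size_t k = 0; k < num_channels; k++)
    {
        double scale = gamma[k] / std::sqrt(variance[k] + epsilon);
        // Scale every weight that feeds output channel k...
        for(auto& w : weights_for_channel(new_weights, k)) // hypothetical helper
            w *= scale;
        // ...and fold the normalization's shift into a per-channel bias.
        new_bias[k] = bias[k] - scale * mean[k];
    }
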
@@ -87,8 +87,8 @@ struct instruction
         }
         return result == computed &&
                std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
-                   return std::find(i->arguments.begin(), i->arguments.end(), *this) !=
-                          i->arguments.end();
+                   return std::find(i->inputs().begin(), i->inputs().end(), *this) !=
+                          i->inputs().end();
                });
     }
@@ -156,7 +156,7 @@ struct instruction
 inline void backreference(instruction_ref ref)
 {
-    for(auto&& arg : ref->arguments)
+    for(auto&& arg : ref->inputs())
         arg->add_output(ref);
 }
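
These two hunks maintain one invariant from opposite ends: valid() checks that an instruction appears in the inputs() of each of its outputs, and backreference() re-creates those reverse edges when an instruction is inserted. Its use later in this same diff, with comments added here:

    // From program::insert_instruction below (annotations are mine):
    auto result = impl->instructions.insert(ins, {op, r, std::move(args)});
    backreference(result);          // register result as an output of each input
    assert(result->valid(begin())); // inputs and outputs now agree in both directions
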
@@ -68,7 +68,7 @@ void verify_instructions(const migraph::program& prog, double tolerance = 80)
     auto create_program = [&] {
         migraph::program p;
         std::vector<migraph::instruction_ref> inputs;
-        for(auto&& arg : ins.arguments)
+        for(auto&& arg : ins.inputs())
         {
             if(arg->name() == "@literal")
                 inputs.push_back(p.add_literal(arg->lit));
@@ -48,10 +48,10 @@ static void print_program(std::ostream& os, const program& p, F annonate)
             os << "{" << ins->lit << "}";
         }
-        if(!ins->arguments.empty())
+        if(!ins->inputs().empty())
         {
             char delim = '(';
-            for(auto&& arg : ins->arguments)
+            for(auto&& arg : ins->inputs())
             {
                 assert(p.has_instruction(arg) && "Instruction not found");
                 os << delim << names.at(arg);
@@ -93,7 +93,7 @@ instruction_ref program::insert_instruction(instruction_ref ins,
     shape r = compute_shape(op, args);
     auto result = impl->instructions.insert(ins, {op, r, std::move(args)});
     backreference(result);
-    // assert(result->arguments == args);
+    // assert(result->inputs() == args);
     assert(result->valid(begin()));
     return result;
 }
@@ -300,9 +300,9 @@ argument generic_eval(const program& p,
         }
         else
         {
-            values.resize(ins->arguments.size());
-            std::transform(ins->arguments.begin(),
-                           ins->arguments.end(),
+            values.resize(ins->inputs().size());
+            std::transform(ins->inputs().begin(),
+                           ins->inputs().end(),
                            values.begin(),
                            [&](instruction_ref i) {
                                assert(results.find(i) != results.end());
@@ -35,9 +35,9 @@ void simplify_reshapes::apply(program& p) const
         std::vector<instruction_ref> reshapes{ins};
         while(is_reshaper(reshapes.back()->name()))
         {
-            assert(!reshapes.back()->arguments.empty());
-            assert(p.has_instruction(reshapes.back()->arguments.front()));
-            reshapes.push_back(reshapes.back()->arguments.front());
+            assert(!reshapes.back()->inputs().empty());
+            assert(p.has_instruction(reshapes.back()->inputs().front()));
+            reshapes.push_back(reshapes.back()->inputs().front());
         }
         std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
@@ -597,30 +597,30 @@ struct cpu_apply
     template <class T>
     void apply_simple_op(instruction_ref ins)
     {
-        prog->replace_instruction(ins, T{}, ins->arguments);
+        prog->replace_instruction(ins, T{}, ins->inputs());
     }

     template <class T, class Op>
     void apply_extend_op(instruction_ref ins)
     {
         auto&& op = any_cast<Op>(ins->op);
-        prog->replace_instruction(ins, T{op}, ins->arguments);
+        prog->replace_instruction(ins, T{op}, ins->inputs());
     }

     void apply_activation(instruction_ref ins)
     {
         auto&& op = any_cast<activation>(ins->op);
         if(op.mode == "relu")
-            prog->replace_instruction(ins, cpu_unary<relu_op>{}, ins->arguments);
+            prog->replace_instruction(ins, cpu_unary<relu_op>{}, ins->inputs());
     }

     void apply_pooling(instruction_ref ins)
     {
         auto&& op = any_cast<pooling>(ins->op);
         if(op.mode == "max")
-            prog->replace_instruction(ins, cpu_pooling<max_pool>{op}, ins->arguments);
+            prog->replace_instruction(ins, cpu_pooling<max_pool>{op}, ins->inputs());
         else if(op.mode == "average")
-            prog->replace_instruction(ins, cpu_pooling<avg_pool>{op}, ins->arguments);
+            prog->replace_instruction(ins, cpu_pooling<avg_pool>{op}, ins->inputs());
     }
 };
@@ -28,12 +28,12 @@ void fuse_ops::apply(program& p) const
     {
         if(ins->name() != "gpu::relu")
             continue;
-        auto add_ins = ins->arguments.front();
+        auto add_ins = ins->inputs().front();
         if(add_ins->name() != "gpu::add")
             continue;
-        auto args = add_ins->arguments;
+        auto args = add_ins->inputs();
         // Use the allocation from the relu operator
-        args.back() = ins->arguments.back();
+        args.back() = ins->inputs().back();
         p.replace_instruction(ins, hip_add_relu{}, args);
     }
 }
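
In graph terms, the rewrite above collapses a relu fed by an add into one fused kernel; judging by the "Use the allocation from the relu operator" comment, the trailing argument of these gpu:: ops appears to carry the output allocation, so the fused instruction inherits the relu's buffer. A before/after sketch (names mirror the hunk):

    // Before: two kernels, two output buffers
    //   t = gpu::add(x, y, alloc_add)
    //   r = gpu::relu(t, alloc_relu)
    // After: one kernel writing directly into the relu's buffer
    //   r = hip_add_relu(x, y, alloc_relu)
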
@@ -372,13 +372,13 @@ struct miopen_apply
         auto&& op = any_cast<convolution>(ins->op);
         auto conv = miopen_convolution{op, make_conv(op)};
-        auto ws = conv.compile(ctx, ins->get_shape(), ins->arguments);
+        auto ws = conv.compile(ctx, ins->get_shape(), ins->inputs());
         auto workspace = insert_allocation(ins, ws, "workspace");
         auto output = insert_allocation(ins, ins->get_shape());
         return prog->replace_instruction(
-            ins, conv, ins->arguments.at(0), ins->arguments.at(1), workspace, output);
+            ins, conv, ins->inputs().at(0), ins->inputs().at(1), workspace, output);
     }

     instruction_ref apply_pooling(instruction_ref ins)
@@ -388,7 +388,7 @@ struct miopen_apply
         auto output = insert_allocation(ins, ins->get_shape());
         return prog->replace_instruction(
-            ins, miopen_pooling{op, std::move(pd)}, ins->arguments.at(0), output);
+            ins, miopen_pooling{op, std::move(pd)}, ins->inputs().at(0), output);
     }

     instruction_ref apply_activation(instruction_ref ins)
@@ -399,7 +399,7 @@ struct miopen_apply
         {
             auto output = insert_allocation(ins, ins->get_shape());
             return prog->replace_instruction(
-                ins, miopen_relu{std::move(ad)}, ins->arguments.at(0), output);
+                ins, miopen_relu{std::move(ad)}, ins->inputs().at(0), output);
         }
         return ins;
     }
@@ -408,7 +408,7 @@ struct miopen_apply
     {
         auto output = insert_allocation(ins, ins->get_shape());
         return prog->replace_instruction(
-            ins, hip_add{}, ins->arguments.at(0), ins->arguments.at(1), output);
+            ins, hip_add{}, ins->inputs().at(0), ins->inputs().at(1), output);
     }

     instruction_ref apply_gemm(instruction_ref ins)
@@ -416,31 +416,31 @@ struct miopen_apply
         auto&& op = any_cast<gemm>(ins->op);
         auto output = insert_allocation(ins, ins->get_shape());
         return prog->replace_instruction(
-            ins, miopen_gemm{op}, ins->arguments.at(0), ins->arguments.at(1), output);
+            ins, miopen_gemm{op}, ins->inputs().at(0), ins->inputs().at(1), output);
     }

     instruction_ref apply_contiguous(instruction_ref ins)
     {
         auto&& op = any_cast<contiguous>(ins->op);
         auto output = insert_allocation(ins, ins->get_shape());
-        return prog->replace_instruction(ins, miopen_contiguous{op}, ins->arguments.at(0), output);
+        return prog->replace_instruction(ins, miopen_contiguous{op}, ins->inputs().at(0), output);
     }

     instruction_ref apply_batch_norm_inference(instruction_ref ins)
     {
         auto&& op = any_cast<batch_norm_inference>(ins->op);
         auto output = insert_allocation(ins, ins->get_shape());
-        shape old_shape = ins->arguments.at(1)->get_shape();
+        shape old_shape = ins->inputs().at(1)->get_shape();
         std::vector<int64_t> new_shape{1, static_cast<int64_t>(old_shape.elements()), 1, 1};
         auto reshape_op = reshape{new_shape};
         std::vector<instruction_ref> reshapes;
-        std::transform(ins->arguments.begin() + 1,
-                       ins->arguments.end(),
+        std::transform(ins->inputs().begin() + 1,
+                       ins->inputs().end(),
                        std::back_inserter(reshapes),
                        [&](auto i) { return prog->insert_instruction(ins, reshape_op, i); });
         return prog->replace_instruction(ins,
                                          miopen_batch_norm_inference{op},
-                                         ins->arguments.at(0),
+                                         ins->inputs().at(0),
                                          reshapes[0],
                                          reshapes[1],
                                          reshapes[2],
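
The reshapes inserted in that last hunk lift each per-channel parameter (gamma, bias, mean, variance: inputs 1 through 4) from shape {C} to {1, C, 1, 1} so it broadcasts across an NCHW activation; input 0, the activations, passes through untouched. A worked example with an assumed channel count:

    // Illustrative: C = 64 per-channel batchnorm parameters.
    // old_shape has lens {64}, so old_shape.elements() == 64 and the code builds:
    std::vector<int64_t> new_shape{1, 64, 1, 1}; // broadcastable over N x 64 x H x W
    // One reshape instruction is inserted per parameter (inputs 1..4),
    // then the fused miopen op consumes the reshaped tensors.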