Commit 767d2885 authored by Paul

Make fields private in instruction class

parent 50361163
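
In short: the op and arguments fields of instruction become private, and callers switch to the get_operator() and inputs() accessors (or the new static instruction::replace helper) instead of touching the fields directly. A minimal sketch of the resulting interface, with stand-in types — only the accessor names are confirmed by the diff below; the exact signatures and member layout are assumed:

#include <list>
#include <vector>

// Stand-ins for the real migraph types, for illustration only:
struct operation
{
};
struct instruction;
using instruction_ref = std::list<instruction>::iterator;

struct instruction
{
    const operation& get_operator() const { return op; }
    const std::vector<instruction_ref>& inputs() const { return arguments; }

private:
    // Formerly public; direct access now requires a friend or, in tests,
    // the MIGRAPH_ROB helper added below.
    operation op;
    std::vector<instruction_ref> arguments;
};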
@@ -37,7 +37,7 @@ void eliminate_contiguous::apply(program& p) const
auto new_args = args;
auto prev = arg->inputs().front();
replace(new_args, arg, prev);
-if(try_compute_shape(ins->op, new_args))
+if(try_compute_shape(ins->get_operator(), new_args))
{
instruction::replace_argument(ins, arg, prev);
}
......
@@ -28,12 +28,12 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
const auto& mean = ins->inputs()[3]->get_literal();
const auto& variance = ins->inputs()[4]->get_literal();
// Get epsilon
-auto bn_op = any_cast<batch_norm_inference>(ins->op);
+auto bn_op = any_cast<batch_norm_inference>(ins->get_operator());
auto epsilon = bn_op.epsilon;
// Get convolution weights
const auto& weights = conv_ins->inputs()[1]->get_literal();
// Get convolution op
-auto conv_op = conv_ins->op;
+auto conv_op = conv_ins->get_operator();
auto weights_lens = weights.get_shape().lens();
auto conv_lens = conv_ins->get_shape().lens();
argument new_weights{weights.get_shape()};
......
@@ -138,6 +138,12 @@ struct instruction
ins->recompute_shape();
}
+static void replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args)
+{
+ins->replace(o, r, std::move(args));
+backreference(ins);
+}
+private:
// internal
void replace(operation o, const shape& r, std::vector<instruction_ref> args)
{
......
@@ -76,7 +76,7 @@ void verify_instructions(const migraph::program& prog, double tolerance = 80)
inputs.push_back(
p.add_parameter(std::to_string(inputs.size()), arg->get_shape()));
}
-p.add_instruction(ins.op, inputs);
+p.add_instruction(ins.get_operator(), inputs);
return p;
};
try
......
@@ -20,7 +20,7 @@ struct program_impl
context ctx;
};
-const operation& get_operation(instruction_ref ins) { return ins->op; }
+const operation& get_operation(instruction_ref ins) { return ins->get_operator(); }
template <class F>
static void print_program(std::ostream& os, const program& p, F annonate)
@@ -33,12 +33,12 @@ static void print_program(std::ostream& os, const program& p, F annonate)
std::string var_name = "@" + std::to_string(count);
if(ins->name() == "@param")
{
-var_name = any_cast<builtin::param>(ins->op).parameter;
+var_name = any_cast<builtin::param>(ins->get_operator()).parameter;
}
os << var_name << " = ";
-os << ins->op;
+os << ins->get_operator();
if(ins->name() == "@literal")
{
@@ -108,8 +108,7 @@ instruction_ref program::replace_instruction(instruction_ref ins,
assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args);
-ins->replace(op, r, std::move(args));
-instruction::backreference(ins);
+instruction::replace(ins, op, r, std::move(args));
assert(ins->valid(begin()));
return ins;
}
@@ -190,7 +189,7 @@ shape program::get_parameter_shape(std::string name) const
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
if(x.name() == "@param")
{
-return any_cast<builtin::param>(x.op).parameter == name;
+return any_cast<builtin::param>(x.get_operator()).parameter == name;
}
else
{
@@ -210,7 +209,7 @@ std::unordered_map<std::string, shape> program::get_parameter_shapes() const
{
if(ins.name() == "@param")
{
-auto&& name = any_cast<builtin::param>(ins.op).parameter;
+auto&& name = any_cast<builtin::param>(ins.get_operator()).parameter;
result[name] = ins.get_shape();
}
}
@@ -291,7 +290,7 @@ argument generic_eval(const program& p,
else if(ins->name() == "@param")
{
results.emplace(ins, trace(ins, [&] {
-return params.at(any_cast<builtin::param>(ins->op).parameter);
+return params.at(any_cast<builtin::param>(ins->get_operator()).parameter);
}));
}
else if(ins->name() == "@outline")
@@ -307,7 +306,7 @@ argument generic_eval(const program& p,
return results[i];
});
results.emplace(
-ins, trace(ins, [&] { return ins->op.compute(ctx, ins->get_shape(), values); }));
+ins, trace(ins, [&] { return ins->get_operator().compute(ctx, ins->get_shape(), values); }));
}
assert(results.find(ins) != results.end());
}
......
@@ -603,20 +603,20 @@ struct cpu_apply
template <class T, class Op>
void apply_extend_op(instruction_ref ins)
{
-auto&& op = any_cast<Op>(ins->op);
+auto&& op = any_cast<Op>(ins->get_operator());
prog->replace_instruction(ins, T{op}, ins->inputs());
}
void apply_activation(instruction_ref ins)
{
-auto&& op = any_cast<activation>(ins->op);
+auto&& op = any_cast<activation>(ins->get_operator());
if(op.mode == "relu")
prog->replace_instruction(ins, cpu_unary<relu_op>{}, ins->inputs());
}
void apply_pooling(instruction_ref ins)
{
-auto&& op = any_cast<pooling>(ins->op);
+auto&& op = any_cast<pooling>(ins->get_operator());
if(op.mode == "max")
prog->replace_instruction(ins, cpu_pooling<max_pool>{op}, ins->inputs());
else if(op.mode == "average")
......
@@ -20,7 +20,7 @@ void eliminate_workspace::apply(program& p) const
continue;
if(ins->name() != "hip::allocate")
continue;
-auto&& a = any_cast<hip_allocate>(ins->op);
+auto&& a = any_cast<hip_allocate>(ins->get_operator());
if(a.tag == "workspace")
{
n = std::max(n, ins->get_shape().bytes());
......
@@ -369,7 +369,7 @@ struct miopen_apply
instruction_ref apply_convolution(instruction_ref ins)
{
-auto&& op = any_cast<convolution>(ins->op);
+auto&& op = any_cast<convolution>(ins->get_operator());
auto conv = miopen_convolution{op, make_conv(op)};
auto ws = conv.compile(ctx, ins->get_shape(), ins->inputs());
@@ -383,7 +383,7 @@ struct miopen_apply
instruction_ref apply_pooling(instruction_ref ins)
{
-auto&& op = any_cast<pooling>(ins->op);
+auto&& op = any_cast<pooling>(ins->get_operator());
auto pd = make_pooling(op);
auto output = insert_allocation(ins, ins->get_shape());
@@ -393,7 +393,7 @@ struct miopen_apply
instruction_ref apply_activation(instruction_ref ins)
{
-auto&& op = any_cast<activation>(ins->op);
+auto&& op = any_cast<activation>(ins->get_operator());
auto ad = make_relu();
if(op.mode == "relu")
{
@@ -413,7 +413,7 @@ struct miopen_apply
instruction_ref apply_gemm(instruction_ref ins)
{
-auto&& op = any_cast<gemm>(ins->op);
+auto&& op = any_cast<gemm>(ins->get_operator());
auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(
ins, miopen_gemm{op}, ins->inputs().at(0), ins->inputs().at(1), output);
@@ -421,14 +421,14 @@ struct miopen_apply
instruction_ref apply_contiguous(instruction_ref ins)
{
-auto&& op = any_cast<contiguous>(ins->op);
+auto&& op = any_cast<contiguous>(ins->get_operator());
auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(ins, miopen_contiguous{op}, ins->inputs().at(0), output);
}
instruction_ref apply_batch_norm_inference(instruction_ref ins)
{
-auto&& op = any_cast<batch_norm_inference>(ins->op);
+auto&& op = any_cast<batch_norm_inference>(ins->get_operator());
auto output = insert_allocation(ins, ins->get_shape());
shape old_shape = ins->inputs().at(1)->get_shape();
std::vector<int64_t> new_shape{1, static_cast<int64_t>(old_shape.elements()), 1, 1};
......
@@ -21,13 +21,13 @@ struct reverse_pass
{
for(auto ins : migraph::iterator_for(p))
{
-if(ins->op.name() == "sum")
+if(ins->name() == "sum")
{
-p.replace_instruction(ins, minus_op{}, ins->arguments);
+p.replace_instruction(ins, minus_op{}, ins->inputs());
}
-else if(ins->op.name() == "minus")
+else if(ins->name() == "minus")
{
-p.replace_instruction(ins, sum_op{}, ins->arguments);
+p.replace_instruction(ins, sum_op{}, ins->inputs());
}
}
}
......
#ifndef MIGRAPH_GUARD_ROB_HPP
#define MIGRAPH_GUARD_ROB_HPP

#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
#endif

// Used to access private member variables from tests. The usual access
// rules do not apply to the template arguments of an explicit
// instantiation, which is what makes this trick legal C++.

// Holds the smuggled member pointer for a given tag type.
template <class Tag>
struct stowed
{
    static typename Tag::type value;
};
template <class Tag>
typename Tag::type stowed<Tag>::value;

// Explicitly instantiating this template stores the member pointer x in
// stowed<Tag>::value, via the constructor of the static instance.
template <class Tag, typename Tag::type x>
struct stow_private
{
    stow_private() { stowed<Tag>::value = x; }
    static stow_private instance;
};
template <class Tag, typename Tag::type x>
stow_private<Tag, x> stow_private<Tag, x>::instance;

// The type of a pointer to data member T of class C.
template <class C, class T>
struct mem_data_ptr
{
    typedef T(C::*type);
};

// Defines a function name(obj) that returns a reference to the private
// data member mem (of type Type) of class C.
#define MIGRAPH_ROB(name, Type, C, mem)                \
    struct name##_tag : mem_data_ptr<C, Type>          \
    {                                                  \
    };                                                 \
    template struct stow_private<name##_tag, &C::mem>; \
    template <class T>                                 \
    auto& name(T&& x)                                  \
    {                                                  \
        return x.*stowed<name##_tag>::value;           \
    }

#ifdef __clang__
#pragma clang diagnostic pop
#endif

#endif
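
Why this compiles at all: access checking is waived for the names used to specify an explicit template instantiation, so the line template struct stow_private<name##_tag, &C::mem>; may legally take the address of a private member; the constructor of the static instance then copies that member pointer into stowed<Tag>::value during static initialization. A toy example of the macro on a hypothetical widget type (not part of this commit):

#include <cassert>
#include <rob.hpp>

struct widget
{
    int get() const { return secret; }

private:
    int secret = 42;
};

// Generates widget_secret(w), returning a reference to the private
// widget::secret member.
MIGRAPH_ROB(widget_secret, int, widget, secret)

int main()
{
    widget w;
    widget_secret(w) = 7; // write through the stolen member pointer
    assert(w.get() == 7);
}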
@@ -2,6 +2,7 @@
#include <migraph/instruction.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
+#include <rob.hpp>
void simple_test()
{
@@ -38,6 +39,8 @@ void incomplete_args()
EXPECT(bool{p.validate() == ins});
}
+MIGRAPH_ROB(access_ins_arguments, std::vector<migraph::instruction_ref>, migraph::instruction, arguments)
void invalid_args()
{
migraph::program p;
@@ -45,7 +48,7 @@ void invalid_args()
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto ins = p.add_instruction(sum_op{}, one, two);
-ins->arguments.clear();
+access_ins_arguments(*ins).clear();
EXPECT(bool{p.validate() == p.begin()});
}