Commit 767d2885 authored by Paul

Make fields private in instruction class

parent 50361163
@@ -37,7 +37,7 @@ void eliminate_contiguous::apply(program& p) const
             auto new_args = args;
             auto prev = arg->inputs().front();
             replace(new_args, arg, prev);
-            if(try_compute_shape(ins->op, new_args))
+            if(try_compute_shape(ins->get_operator(), new_args))
             {
                 instruction::replace_argument(ins, arg, prev);
             }
...
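From here on, most hunks make the same mechanical substitution: reads of the formerly public `op` field become calls to a `get_operator()` accessor. The accessor's definition is not part of the hunks shown; a plausible minimal sketch of the new shape of `struct instruction`, for orientation only:

```cpp
struct instruction
{
    // Sketch only; the real class has many more members.
    const operation& get_operator() const { return op; }

    private:
    operation op; // direct access now limited to instruction itself
};
```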
@@ -28,12 +28,12 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
         const auto& mean = ins->inputs()[3]->get_literal();
         const auto& variance = ins->inputs()[4]->get_literal();
         // Get epsilon
-        auto bn_op = any_cast<batch_norm_inference>(ins->op);
+        auto bn_op = any_cast<batch_norm_inference>(ins->get_operator());
         auto epsilon = bn_op.epsilon;
         // Get convolution weights
         const auto& weights = conv_ins->inputs()[1]->get_literal();
         // Get convolution op
-        auto conv_op = conv_ins->op;
+        auto conv_op = conv_ins->get_operator();
         auto weights_lens = weights.get_shape().lens();
         auto conv_lens = conv_ins->get_shape().lens();
         argument new_weights{weights.get_shape()};
...
@@ -138,6 +138,12 @@ struct instruction
         ins->recompute_shape();
     }

+    static void replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args)
+    {
+        ins->replace(o, r, std::move(args));
+        backreference(ins);
+    }
+    private:
     // internal
     void replace(operation o, const shape& r, std::vector<instruction_ref> args)
     {
...
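This is the one structural change in the commit: replacing an instruction's operation used to require two calls at every call site, and the new static helper folds them together so the `backreference` fixup cannot be forgotten now that the member `replace` is private. The call-site effect, as seen in program.cpp below:

```cpp
// before: two steps at every call site
ins->replace(op, r, std::move(args));
instruction::backreference(ins);

// after: one static helper does both
instruction::replace(ins, op, r, std::move(args));
```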
@@ -76,7 +76,7 @@ void verify_instructions(const migraph::program& prog, double tolerance = 80)
                 inputs.push_back(
                     p.add_parameter(std::to_string(inputs.size()), arg->get_shape()));
             }
-            p.add_instruction(ins.op, inputs);
+            p.add_instruction(ins.get_operator(), inputs);
             return p;
         };
         try
...
@@ -20,7 +20,7 @@ struct program_impl
     context ctx;
 };

-const operation& get_operation(instruction_ref ins) { return ins->op; }
+const operation& get_operation(instruction_ref ins) { return ins->get_operator(); }

 template <class F>
 static void print_program(std::ostream& os, const program& p, F annonate)
@@ -33,12 +33,12 @@ static void print_program(std::ostream& os, const program& p, F annonate)
         std::string var_name = "@" + std::to_string(count);
         if(ins->name() == "@param")
         {
-            var_name = any_cast<builtin::param>(ins->op).parameter;
+            var_name = any_cast<builtin::param>(ins->get_operator()).parameter;
         }
         os << var_name << " = ";
-        os << ins->op;
+        os << ins->get_operator();
         if(ins->name() == "@literal")
         {
@@ -108,8 +108,7 @@ instruction_ref program::replace_instruction(instruction_ref ins,
     assert(not starts_with(op.name(), "@"));

     shape r = compute_shape(op, args);
-    ins->replace(op, r, std::move(args));
-    instruction::backreference(ins);
+    instruction::replace(ins, op, r, std::move(args));
     assert(ins->valid(begin()));
     return ins;
 }
@@ -190,7 +189,7 @@ shape program::get_parameter_shape(std::string name) const
         impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
             if(x.name() == "@param")
             {
-                return any_cast<builtin::param>(x.op).parameter == name;
+                return any_cast<builtin::param>(x.get_operator()).parameter == name;
             }
             else
             {
@@ -210,7 +209,7 @@ std::unordered_map<std::string, shape> program::get_parameter_shapes() const
     {
         if(ins.name() == "@param")
         {
-            auto&& name = any_cast<builtin::param>(ins.op).parameter;
+            auto&& name = any_cast<builtin::param>(ins.get_operator()).parameter;
             result[name] = ins.get_shape();
         }
     }
@@ -291,7 +290,7 @@ argument generic_eval(const program& p,
         else if(ins->name() == "@param")
         {
             results.emplace(ins, trace(ins, [&] {
-                return params.at(any_cast<builtin::param>(ins->op).parameter);
+                return params.at(any_cast<builtin::param>(ins->get_operator()).parameter);
             }));
         }
         else if(ins->name() == "@outline")
@@ -307,7 +306,7 @@ argument generic_eval(const program& p,
                 return results[i];
             });
             results.emplace(
-                ins, trace(ins, [&] { return ins->op.compute(ctx, ins->get_shape(), values); }));
+                ins, trace(ins, [&] { return ins->get_operator().compute(ctx, ins->get_shape(), values); }));
         }
         assert(results.find(ins) != results.end());
     }
...
@@ -603,20 +603,20 @@ struct cpu_apply
     template <class T, class Op>
     void apply_extend_op(instruction_ref ins)
     {
-        auto&& op = any_cast<Op>(ins->op);
+        auto&& op = any_cast<Op>(ins->get_operator());
         prog->replace_instruction(ins, T{op}, ins->inputs());
     }

     void apply_activation(instruction_ref ins)
     {
-        auto&& op = any_cast<activation>(ins->op);
+        auto&& op = any_cast<activation>(ins->get_operator());
         if(op.mode == "relu")
             prog->replace_instruction(ins, cpu_unary<relu_op>{}, ins->inputs());
     }

     void apply_pooling(instruction_ref ins)
     {
-        auto&& op = any_cast<pooling>(ins->op);
+        auto&& op = any_cast<pooling>(ins->get_operator());
         if(op.mode == "max")
             prog->replace_instruction(ins, cpu_pooling<max_pool>{op}, ins->inputs());
         else if(op.mode == "average")
...
@@ -20,7 +20,7 @@ void eliminate_workspace::apply(program& p) const
             continue;
         if(ins->name() != "hip::allocate")
             continue;
-        auto&& a = any_cast<hip_allocate>(ins->op);
+        auto&& a = any_cast<hip_allocate>(ins->get_operator());
         if(a.tag == "workspace")
         {
             n = std::max(n, ins->get_shape().bytes());
...
@@ -369,7 +369,7 @@ struct miopen_apply
     instruction_ref apply_convolution(instruction_ref ins)
     {
-        auto&& op = any_cast<convolution>(ins->op);
+        auto&& op = any_cast<convolution>(ins->get_operator());
         auto conv = miopen_convolution{op, make_conv(op)};
         auto ws = conv.compile(ctx, ins->get_shape(), ins->inputs());
@@ -383,7 +383,7 @@ struct miopen_apply
     instruction_ref apply_pooling(instruction_ref ins)
     {
-        auto&& op = any_cast<pooling>(ins->op);
+        auto&& op = any_cast<pooling>(ins->get_operator());
         auto pd = make_pooling(op);
         auto output = insert_allocation(ins, ins->get_shape());
@@ -393,7 +393,7 @@ struct miopen_apply
     instruction_ref apply_activation(instruction_ref ins)
     {
-        auto&& op = any_cast<activation>(ins->op);
+        auto&& op = any_cast<activation>(ins->get_operator());
         auto ad = make_relu();
         if(op.mode == "relu")
         {
@@ -413,7 +413,7 @@ struct miopen_apply
     instruction_ref apply_gemm(instruction_ref ins)
     {
-        auto&& op = any_cast<gemm>(ins->op);
+        auto&& op = any_cast<gemm>(ins->get_operator());
         auto output = insert_allocation(ins, ins->get_shape());
         return prog->replace_instruction(
             ins, miopen_gemm{op}, ins->inputs().at(0), ins->inputs().at(1), output);
@@ -421,14 +421,14 @@ struct miopen_apply
     instruction_ref apply_contiguous(instruction_ref ins)
     {
-        auto&& op = any_cast<contiguous>(ins->op);
+        auto&& op = any_cast<contiguous>(ins->get_operator());
         auto output = insert_allocation(ins, ins->get_shape());
         return prog->replace_instruction(ins, miopen_contiguous{op}, ins->inputs().at(0), output);
     }

     instruction_ref apply_batch_norm_inference(instruction_ref ins)
     {
-        auto&& op = any_cast<batch_norm_inference>(ins->op);
+        auto&& op = any_cast<batch_norm_inference>(ins->get_operator());
         auto output = insert_allocation(ins, ins->get_shape());
         shape old_shape = ins->inputs().at(1)->get_shape();
         std::vector<int64_t> new_shape{1, static_cast<int64_t>(old_shape.elements()), 1, 1};
...
@@ -21,13 +21,13 @@ struct reverse_pass
     {
         for(auto ins : migraph::iterator_for(p))
         {
-            if(ins->op.name() == "sum")
+            if(ins->name() == "sum")
             {
-                p.replace_instruction(ins, minus_op{}, ins->arguments);
+                p.replace_instruction(ins, minus_op{}, ins->inputs());
             }
-            else if(ins->op.name() == "minus")
+            else if(ins->name() == "minus")
             {
-                p.replace_instruction(ins, sum_op{}, ins->arguments);
+                p.replace_instruction(ins, sum_op{}, ins->inputs());
             }
         }
     }
...
#ifndef MIGRAPH_GUARD_ROB_HPP
#define MIGRAPH_GUARD_ROB_HPP

#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
#endif

// Used to access private member variables
template <class Tag>
struct stowed
{
    static typename Tag::type value;
};
template <class Tag>
typename Tag::type stowed<Tag>::value;

template <class Tag, typename Tag::type x>
struct stow_private
{
    stow_private() { stowed<Tag>::value = x; }
    static stow_private instance;
};
template <class Tag, typename Tag::type x>
stow_private<Tag, x> stow_private<Tag, x>::instance;

template <class C, class T>
struct mem_data_ptr
{
    typedef T(C::*type);
};

#define MIGRAPH_ROB(name, Type, C, mem)                \
    struct name##_tag : mem_data_ptr<C, Type>          \
    {                                                  \
    };                                                 \
    template struct stow_private<name##_tag, &C::mem>; \
    template <class T>                                 \
    auto& name(T&& x)                                  \
    {                                                  \
        return x.*stowed<name##_tag>::value;           \
    }

#ifdef __clang__
#pragma clang diagnostic pop
#endif

#endif
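This new header relies on the well-known explicit-instantiation loophole: the usual access rules do not apply to names used in an explicit template instantiation, so `&C::mem` may be formed there even when `mem` is private. The constructor of the explicitly instantiated static member `instance` then runs at startup and stashes that member pointer in `stowed<Tag>::value`, which the generated `name(x)` function dereferences. A minimal, self-contained sketch of the macro in use (`widget`, `secret`, and `read_secret` are illustrative names, not from the commit):

```cpp
#include <iostream>
#include <rob.hpp>

class widget
{
    int secret = 42; // private: ordinary code cannot read this
};

// Explicit instantiation may name private members, so forming
// &widget::secret is legal here and nowhere else.
MIGRAPH_ROB(read_secret, int, widget, secret)

int main()
{
    widget w;
    std::cout << read_secret(w) << '\n'; // prints 42
}
```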
@@ -2,6 +2,7 @@
 #include <migraph/instruction.hpp>
 #include <basic_ops.hpp>
 #include <test.hpp>
+#include <rob.hpp>

 void simple_test()
 {
@@ -38,6 +39,8 @@ void incomplete_args()
     EXPECT(bool{p.validate() == ins});
 }

+MIGRAPH_ROB(access_ins_arguments, std::vector<migraph::instruction_ref>, migraph::instruction, arguments)
+
 void invalid_args()
 {
     migraph::program p;
@@ -45,7 +48,7 @@ void invalid_args()
     auto one = p.add_literal(1);
     auto two = p.add_literal(2);
     auto ins = p.add_instruction(sum_op{}, one, two);
-    ins->arguments.clear();
+    access_ins_arguments(*ins).clear();
     EXPECT(bool{p.validate() == p.begin()});
 }
...
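With `arguments` now private, this test can no longer corrupt an instruction directly, so it reaches through the `MIGRAPH_ROB` backdoor defined above to clear the argument list and then checks that `program::validate()` still flags the broken instruction.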