"tests/models/videomae/test_image_processing_videomae.py" did not exist on "672b66262aa4e65863f8aa94743fdd3c2a27a10b"
Unverified Commit 2466dd6f authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Refactor program to module (#684)



* code backup

* clang format

* change corresponding tool files

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent de10423f
...@@ -8,7 +8,7 @@ namespace migraphx { ...@@ -8,7 +8,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
void sync_device::apply(program& p) const void sync_device::apply(module& p) const
{ {
auto last = std::prev(p.end()); auto last = std::prev(p.end());
if(last->name() == "@return") if(last->name() == "@return")
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
namespace migraphx { namespace migraphx {
...@@ -10,7 +11,7 @@ namespace gpu { ...@@ -10,7 +11,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_COPY_LITERALS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_COPY_LITERALS)
void write_literals::apply(program& p) const void write_literals::apply(module& p) const
{ {
assert(ctx != nullptr); assert(ctx != nullptr);
std::size_t n = 0; std::size_t n = 0;
......
...@@ -11,7 +11,7 @@ namespace ref { ...@@ -11,7 +11,7 @@ namespace ref {
struct lowering struct lowering
{ {
std::string name() const { return "ref::lowering"; } std::string name() const { return "ref::lowering"; }
void apply(program& p) const; void apply(module& m) const;
}; };
} // namespace ref } // namespace ref
......
...@@ -882,7 +882,7 @@ MIGRAPHX_REGISTER_OP(ref_rnn_var_sl_last_output) ...@@ -882,7 +882,7 @@ MIGRAPHX_REGISTER_OP(ref_rnn_var_sl_last_output)
struct ref_apply struct ref_apply
{ {
program* prog; module* modl;
std::unordered_map<std::string, std::function<void(instruction_ref)>> apply_map{}; std::unordered_map<std::string, std::function<void(instruction_ref)>> apply_map{};
template <class T> template <class T>
...@@ -922,7 +922,7 @@ struct ref_apply ...@@ -922,7 +922,7 @@ struct ref_apply
void apply() void apply()
{ {
init(); init();
for(auto it : iterator_for(*prog)) for(auto it : iterator_for(*modl))
{ {
if(it->name() == "pooling") if(it->name() == "pooling")
{ {
...@@ -941,33 +941,33 @@ struct ref_apply ...@@ -941,33 +941,33 @@ struct ref_apply
void apply_ref_op(instruction_ref ins) const void apply_ref_op(instruction_ref ins) const
{ {
prog->replace_instruction(ins, ref_op{ins->get_operator()}, ins->inputs()); modl->replace_instruction(ins, ref_op{ins->get_operator()}, ins->inputs());
} }
template <class T> template <class T>
void apply_simple_op(instruction_ref ins) void apply_simple_op(instruction_ref ins)
{ {
prog->replace_instruction(ins, T{}, ins->inputs()); modl->replace_instruction(ins, T{}, ins->inputs());
} }
template <class T, class Op> template <class T, class Op>
void apply_extend_op(instruction_ref ins) void apply_extend_op(instruction_ref ins)
{ {
auto&& op = any_cast<Op>(ins->get_operator()); auto&& op = any_cast<Op>(ins->get_operator());
prog->replace_instruction(ins, T{op}, ins->inputs()); modl->replace_instruction(ins, T{op}, ins->inputs());
} }
void apply_pooling(instruction_ref ins) const void apply_pooling(instruction_ref ins) const
{ {
auto&& op = any_cast<op::pooling>(ins->get_operator()); auto&& op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "max") if(op.mode == "max")
prog->replace_instruction(ins, ref_pooling<max_pool>{op}, ins->inputs()); modl->replace_instruction(ins, ref_pooling<max_pool>{op}, ins->inputs());
else if(op.mode == "average") else if(op.mode == "average")
prog->replace_instruction(ins, ref_pooling<avg_pool>{op}, ins->inputs()); modl->replace_instruction(ins, ref_pooling<avg_pool>{op}, ins->inputs());
} }
}; };
void lowering::apply(program& p) const { ref_apply{&p}.apply(); } void lowering::apply(module& m) const { ref_apply{&m}.apply(); }
} // namespace ref } // namespace ref
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -33,6 +33,7 @@ struct tf_parser ...@@ -33,6 +33,7 @@ struct tf_parser
std::vector<tensorflow::NodeDef> input_nodes; std::vector<tensorflow::NodeDef> input_nodes;
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
program prog = program(); program prog = program();
module* mm = prog.get_main_module();
bool is_nhwc = true; bool is_nhwc = true;
unsigned int batch_size = 1; unsigned int batch_size = 1;
...@@ -43,33 +44,33 @@ struct tf_parser ...@@ -43,33 +44,33 @@ struct tf_parser
return is_nhwc and ins->get_shape().lens().size() == 4; return is_nhwc and ins->get_shape().lens().size() == 4;
} }
instruction_ref to_nhwc(instruction_ref ins) instruction_ref to_nhwc(instruction_ref ins) const
{ {
if(should_transpose(ins)) if(should_transpose(ins))
return prog.add_instruction(op::transpose{{0, 2, 3, 1}}, ins); return mm->add_instruction(op::transpose{{0, 2, 3, 1}}, ins);
return ins; return ins;
} }
instruction_ref to_nchw(instruction_ref ins) instruction_ref to_nchw(instruction_ref ins) const
{ {
if(should_transpose(ins)) if(should_transpose(ins))
return prog.add_instruction(op::transpose{{0, 3, 1, 2}}, ins); return mm->add_instruction(op::transpose{{0, 3, 1, 2}}, ins);
return ins; return ins;
} }
instruction_ref to_kcxy(instruction_ref ins) instruction_ref to_kcxy(instruction_ref ins) const
{ {
if(should_transpose(ins)) if(should_transpose(ins))
return prog.add_instruction(op::transpose{{3, 2, 0, 1}}, ins); return mm->add_instruction(op::transpose{{3, 2, 0, 1}}, ins);
return ins; return ins;
} }
instruction_ref make_contiguous(instruction_ref ins) instruction_ref make_contiguous(instruction_ref ins) const
{ {
if(ins->get_shape().standard()) if(ins->get_shape().standard())
return ins; return ins;
else else
return prog.add_instruction(op::contiguous{}, ins); return mm->add_instruction(op::contiguous{}, ins);
} }
std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args) std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args)
...@@ -266,7 +267,7 @@ struct tf_parser ...@@ -266,7 +267,7 @@ struct tf_parser
// { // {
// if(is_nhwc) // if(is_nhwc)
// { // {
// l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]); // l0 = mm->add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
// } // }
// } // }
return add_broadcastable_binary_op(args[0], args[1], x); return add_broadcastable_binary_op(args[0], args[1], x);
...@@ -308,13 +309,13 @@ struct tf_parser ...@@ -308,13 +309,13 @@ struct tf_parser
output_lens.begin() + offset, output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); }); [](auto a, auto b) { return std::max(a, b); });
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0); auto l0 = mm->add_instruction(op::multibroadcast{output_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1); auto l1 = mm->add_instruction(op::multibroadcast{output_lens}, arg1);
return to_nhwc(prog.add_instruction(x, to_nchw(l0), to_nchw(l1))); return to_nhwc(mm->add_instruction(x, to_nchw(l0), to_nchw(l1)));
} }
else else
{ {
return to_nhwc(prog.add_instruction(x, {to_nchw(arg0), to_nchw(arg1)})); return to_nhwc(mm->add_instruction(x, {to_nchw(arg0), to_nchw(arg1)}));
} }
} }
...@@ -323,7 +324,7 @@ struct tf_parser ...@@ -323,7 +324,7 @@ struct tf_parser
{ {
add_op(name, add_op(name,
[this, x](const attribute_map&, std::vector<instruction_ref> args) { [this, x](const attribute_map&, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args); return mm->add_instruction(x, args);
}, },
transpose); transpose);
} }
...@@ -334,12 +335,13 @@ struct tf_parser ...@@ -334,12 +335,13 @@ struct tf_parser
{ {
int64_t axis = 0; int64_t axis = 0;
axis = args[1]->eval().at<int64_t>(); axis = args[1]->eval().at<int64_t>();
auto ins = prog.add_instruction(Op{axis}, args.front()); auto ins = mm->add_instruction(Op{axis}, args.front());
return prog.add_instruction(op::squeeze{{axis}}, ins); return mm->add_instruction(op::squeeze{{axis}}, ins);
} }
instruction_ref instruction_ref parse_batchnorm(const std::string&,
parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) attribute_map attributes,
std::vector<instruction_ref> args) const
{ {
float epsilon = 1e-5f; float epsilon = 1e-5f;
float momentum = 0.9f; float momentum = 0.9f;
...@@ -349,46 +351,49 @@ struct tf_parser ...@@ -349,46 +351,49 @@ struct tf_parser
epsilon = attributes.at("epsilon").f(); epsilon = attributes.at("epsilon").f();
} }
op::batch_norm_inference op{epsilon, momentum, bn_mode}; op::batch_norm_inference op{epsilon, momentum, bn_mode};
return prog.add_instruction(op, std::move(args)); return mm->add_instruction(op, std::move(args));
} }
instruction_ref instruction_ref
parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{ {
uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel) uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
auto l0 = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()}, args[1]); auto l0 = mm->add_instruction(op::broadcast{axis, args[0]->get_shape().lens()}, args[1]);
return prog.add_instruction(op::add{}, args[0], l0); return mm->add_instruction(op::add{}, args[0], l0);
} }
instruction_ref instruction_ref parse_cast(const std::string&,
parse_cast(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) attribute_map attributes,
std::vector<instruction_ref> args) const
{ {
shape::type_t type = parse_type(attributes.at("DstT").type()); shape::type_t type = parse_type(attributes.at("DstT").type());
return prog.add_instruction(op::convert{type}, std::move(args)); return mm->add_instruction(op::convert{type}, std::move(args));
} }
instruction_ref instruction_ref parse_concat(const std::string&,
parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) attribute_map attributes,
std::vector<instruction_ref> args) const
{ {
// get index for axis within args // get index for axis within args
size_t axis_idx = attributes.at("N").i(); size_t axis_idx = attributes.at("N").i();
int64_t axis = args[axis_idx]->eval().at<int64_t>(); int64_t axis = args[axis_idx]->eval().at<int64_t>();
op::concat op{axis}; op::concat op{axis};
// return only first N arguments (assuming last index is the axis value) // return only first N arguments (assuming last index is the axis value)
return prog.add_instruction( return mm->add_instruction(
op, std::vector<instruction_ref>(args.begin(), args.begin() + args.size() - 1)); op, std::vector<instruction_ref>(args.begin(), args.begin() + args.size() - 1));
} }
instruction_ref parse_constant(const std::string&, instruction_ref parse_constant(const std::string&,
attribute_map attributes, attribute_map attributes,
const std::vector<instruction_ref>&) const std::vector<instruction_ref>&) const
{ {
literal v = parse_tensor(attributes.at("value").tensor()); literal v = parse_tensor(attributes.at("value").tensor());
return prog.add_literal(v); return mm->add_literal(v);
} }
instruction_ref instruction_ref parse_conv(const std::string&,
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) attribute_map attributes,
std::vector<instruction_ref> args) const
{ {
op::convolution op; op::convolution op;
if(contains(attributes, "strides")) if(contains(attributes, "strides"))
...@@ -436,7 +441,7 @@ struct tf_parser ...@@ -436,7 +441,7 @@ struct tf_parser
if(pads[0] != pads[2] || pads[1] != pads[3]) if(pads[0] != pads[2] || pads[1] != pads[3])
{ {
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]}; std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = prog.add_instruction(migraphx::op::pad{padding}, l0); l0 = mm->add_instruction(migraphx::op::pad{padding}, l0);
} }
else else
{ {
...@@ -464,12 +469,12 @@ struct tf_parser ...@@ -464,12 +469,12 @@ struct tf_parser
op.padding[1] = padding[1]; op.padding[1] = padding[1];
} }
} }
return prog.add_instruction(op, {l0, to_kcxy(args[1])}); return mm->add_instruction(op, {l0, to_kcxy(args[1])});
} }
instruction_ref parse_depthwiseconv(const std::string&, instruction_ref parse_depthwiseconv(const std::string&,
attribute_map attributes, attribute_map attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args) const
{ {
op::convolution op; op::convolution op;
size_t num_channels = args[0]->get_shape().lens()[1]; size_t num_channels = args[0]->get_shape().lens()[1];
...@@ -522,7 +527,7 @@ struct tf_parser ...@@ -522,7 +527,7 @@ struct tf_parser
if(pads[0] != pads[2] || pads[1] != pads[3]) if(pads[0] != pads[2] || pads[1] != pads[3])
{ {
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]}; std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = prog.add_instruction(migraphx::op::pad{padding}, l0); l0 = mm->add_instruction(migraphx::op::pad{padding}, l0);
} }
else else
{ {
...@@ -548,13 +553,14 @@ struct tf_parser ...@@ -548,13 +553,14 @@ struct tf_parser
new_weights_shape[1] = 1; new_weights_shape[1] = 1;
// Make sure weights are contiguous before doing reshape // Make sure weights are contiguous before doing reshape
auto new_weights = auto new_weights =
prog.add_instruction(op::reshape{new_weights_shape}, make_contiguous(weights)); mm->add_instruction(op::reshape{new_weights_shape}, make_contiguous(weights));
return prog.add_instruction(op, {l0, new_weights}); return mm->add_instruction(op, {l0, new_weights});
} }
instruction_ref instruction_ref parse_expanddims(const std::string&,
parse_expanddims(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const attribute_map&,
std::vector<instruction_ref> args) const
{ {
std::vector<size_t> input_dims = args[0]->get_shape().lens(); std::vector<size_t> input_dims = args[0]->get_shape().lens();
std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end()); std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end());
...@@ -569,19 +575,20 @@ struct tf_parser ...@@ -569,19 +575,20 @@ struct tf_parser
{ {
new_dims.insert(new_dims.begin() + dim, 1); new_dims.insert(new_dims.begin() + dim, 1);
} }
return prog.add_instruction(op::reshape{new_dims}, args[0]); return mm->add_instruction(op::reshape{new_dims}, args[0]);
} }
instruction_ref instruction_ref
parse_gather(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_gather(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{ {
int axis = args[2]->eval().at<int32_t>(); int axis = args[2]->eval().at<int32_t>();
op::gather op{axis}; op::gather op{axis};
return prog.add_instruction(op, {args[0], args[1]}); return mm->add_instruction(op, {args[0], args[1]});
} }
instruction_ref instruction_ref parse_matmul(const std::string&,
parse_matmul(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) attribute_map attributes,
std::vector<instruction_ref> args) const
{ {
bool transa = false; bool transa = false;
bool transb = false; bool transb = false;
...@@ -609,31 +616,33 @@ struct tf_parser ...@@ -609,31 +616,33 @@ struct tf_parser
// swap the last two elements // swap the last two elements
std::iter_swap(perm.end() - 1, perm.end() - 2); std::iter_swap(perm.end() - 1, perm.end() - 2);
auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0]; auto l1 = (transa) ? mm->add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1]; auto l2 = (transb) ? mm->add_instruction(op::transpose{perm}, args[1]) : args[1];
return prog.add_instruction(op::dot{}, l1, l2); return mm->add_instruction(op::dot{}, l1, l2);
} }
instruction_ref instruction_ref parse_mean(const std::string&,
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) attribute_map attributes,
std::vector<instruction_ref> args) const
{ {
bool keep_dims = attributes.at("keep_dims").b(); bool keep_dims = attributes.at("keep_dims").b();
auto axes = args[1]->eval().get<int32_t>().to_vector<int64_t>(); auto axes = args[1]->eval().get<int32_t>().to_vector<int64_t>();
if(keep_dims) if(keep_dims)
{ {
return prog.add_instruction(op::reduce_mean{axes}, args[0]); return mm->add_instruction(op::reduce_mean{axes}, args[0]);
} }
else else
{ {
auto ins = prog.add_instruction(op::reduce_mean{axes}, args[0]); auto ins = mm->add_instruction(op::reduce_mean{axes}, args[0]);
return prog.add_instruction(op::squeeze{axes}, ins); return mm->add_instruction(op::squeeze{axes}, ins);
} }
} }
instruction_ref instruction_ref parse_onehot(const std::string&,
parse_onehot(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) attribute_map attributes,
std::vector<instruction_ref> args) const
{ {
size_t depth = static_cast<size_t>(args[1]->eval().at<int32_t>()); size_t depth = static_cast<size_t>(args[1]->eval().at<int32_t>());
...@@ -652,15 +661,15 @@ struct tf_parser ...@@ -652,15 +661,15 @@ struct tf_parser
if(axis == -1) if(axis == -1)
{ {
shape s{shape::float_type, {depth, depth}}; shape s{shape::float_type, {depth, depth}};
auto l0 = prog.add_literal({s, depth_input}); auto l0 = mm->add_literal({s, depth_input});
return prog.add_instruction(op::gather{0}, {l0, args[0]}); return mm->add_instruction(op::gather{0}, {l0, args[0]});
} }
MIGRAPHX_THROW("MIGraphX does not support axis != -1"); MIGRAPHX_THROW("MIGraphX does not support axis != -1");
} }
instruction_ref parse_pack(const std::string&, instruction_ref parse_pack(const std::string&,
const attribute_map& attributes, const attribute_map& attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args) const
{ {
// reinterpret as unsqueeze with concat // reinterpret as unsqueeze with concat
std::vector<instruction_ref> unsqueezed_args; std::vector<instruction_ref> unsqueezed_args;
...@@ -678,12 +687,12 @@ struct tf_parser ...@@ -678,12 +687,12 @@ struct tf_parser
args.begin(), args.begin(),
args.end(), args.end(),
std::back_inserter(unsqueezed_args), std::back_inserter(unsqueezed_args),
[&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); }); [&](instruction_ref arg) { return mm->add_instruction(op::unsqueeze{{axis}}, arg); });
return to_nhwc(prog.add_instruction(op::concat{axis}, unsqueezed_args)); return to_nhwc(mm->add_instruction(op::concat{axis}, unsqueezed_args));
} }
instruction_ref instruction_ref
parse_pad(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_pad(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{ {
size_t ndims = args.front()->get_shape().lens().size(); size_t ndims = args.front()->get_shape().lens().size();
...@@ -706,12 +715,12 @@ struct tf_parser ...@@ -706,12 +715,12 @@ struct tf_parser
pads[i + ndims] = pad_per_dim[i].second; pads[i + ndims] = pad_per_dim[i].second;
} }
op.pads = pads; op.pads = pads;
return prog.add_instruction(op, args.front()); return mm->add_instruction(op, args.front());
} }
instruction_ref parse_pooling(const std::string& name, instruction_ref parse_pooling(const std::string& name,
attribute_map attributes, attribute_map attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args) const
{ {
op::pooling op{starts_with(name, "Max") ? "max" : "average"}; op::pooling op{starts_with(name, "Max") ? "max" : "average"};
...@@ -754,7 +763,7 @@ struct tf_parser ...@@ -754,7 +763,7 @@ struct tf_parser
if(pads[0] != pads[2] || pads[1] != pads[3]) if(pads[0] != pads[2] || pads[1] != pads[3])
{ {
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]}; std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = prog.add_instruction( l0 = mm->add_instruction(
migraphx::op::pad{padding, std::numeric_limits<float>::lowest()}, l0); migraphx::op::pad{padding, std::numeric_limits<float>::lowest()}, l0);
} }
else else
...@@ -764,47 +773,47 @@ struct tf_parser ...@@ -764,47 +773,47 @@ struct tf_parser
} }
} }
} }
return prog.add_instruction(op, l0); return mm->add_instruction(op, l0);
} }
instruction_ref instruction_ref
parse_relu6(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_relu6(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{ {
auto input_lens = args[0]->get_shape().lens(); auto input_lens = args[0]->get_shape().lens();
auto min_val = prog.add_literal(0.0f); auto min_val = mm->add_literal(0.0f);
auto max_val = prog.add_literal(6.0f); auto max_val = mm->add_literal(6.0f);
min_val = prog.add_instruction(op::multibroadcast{input_lens}, min_val); min_val = mm->add_instruction(op::multibroadcast{input_lens}, min_val);
max_val = prog.add_instruction(op::multibroadcast{input_lens}, max_val); max_val = mm->add_instruction(op::multibroadcast{input_lens}, max_val);
return prog.add_instruction(op::clip{}, args.front(), min_val, max_val); return mm->add_instruction(op::clip{}, args.front(), min_val, max_val);
} }
instruction_ref instruction_ref
parse_reshape(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_reshape(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{ {
op::reshape op; op::reshape op;
if(args.size() != 2) if(args.size() != 2)
MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)"); MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");
auto s = args[1]->eval(); auto s = args[1]->eval();
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
return prog.add_instruction(op, make_contiguous(args[0])); return mm->add_instruction(op, make_contiguous(args[0]));
} }
// Use a literal instruction to replace the shape since output of // Use a literal instruction to replace the shape since output of
// shape operator are literals in migraphx // shape operator are literals in migraphx
instruction_ref instruction_ref
parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{ {
std::vector<std::size_t> arg_shape = args[0]->get_shape().lens(); std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
std::vector<int32_t> vec_shape(arg_shape.size()); std::vector<int32_t> vec_shape(arg_shape.size());
migraphx::shape s(migraphx::shape::int32_type, {arg_shape.size()}); migraphx::shape s(migraphx::shape::int32_type, {arg_shape.size()});
std::transform( std::transform(
arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) { return i; }); arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) { return i; });
return prog.add_literal(migraphx::literal{s, vec_shape}); return mm->add_literal(migraphx::literal{s, vec_shape});
} }
instruction_ref instruction_ref
parse_slice(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_slice(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{ {
op::slice op; op::slice op;
auto starts = args[1]->eval().get<int32_t>().to_vector(); auto starts = args[1]->eval().get<int32_t>().to_vector();
...@@ -823,7 +832,7 @@ struct tf_parser ...@@ -823,7 +832,7 @@ struct tf_parser
else else
op.ends[i] = starts[i] + size[i]; op.ends[i] = starts[i] + size[i];
} }
return prog.add_instruction(op, make_contiguous(args[0])); return mm->add_instruction(op, make_contiguous(args[0]));
} }
// template to facilitate the logsoftmax later // template to facilitate the logsoftmax later
...@@ -843,12 +852,12 @@ struct tf_parser ...@@ -843,12 +852,12 @@ struct tf_parser
axis += num_dims; axis += num_dims;
} }
return prog.add_instruction(Op{axis}, make_contiguous(args[0])); return mm->add_instruction(Op{axis}, make_contiguous(args[0]));
} }
std::vector<instruction_ref> parse_split(const std::string&, std::vector<instruction_ref> parse_split(const std::string&,
const attribute_map& attributes, const attribute_map& attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args) const
{ {
bool vector_as_input = args.size() == 3; bool vector_as_input = args.size() == 3;
int num_outputs = 1; int num_outputs = 1;
...@@ -874,7 +883,7 @@ struct tf_parser ...@@ -874,7 +883,7 @@ struct tf_parser
assert(num_outputs > 0); assert(num_outputs > 0);
if(num_outputs == 1) if(num_outputs == 1)
return std::vector<instruction_ref>{prog.add_instruction(op::identity{}, input_arg)}; return std::vector<instruction_ref>{mm->add_instruction(op::identity{}, input_arg)};
auto lens = input_arg->get_shape().lens(); auto lens = input_arg->get_shape().lens();
auto num_dims = lens.size(); auto num_dims = lens.size();
...@@ -919,14 +928,14 @@ struct tf_parser ...@@ -919,14 +928,14 @@ struct tf_parser
op.starts[axis] = slice_pos[i]; op.starts[axis] = slice_pos[i];
op.ends[axis] = slice_pos[i + 1]; op.ends[axis] = slice_pos[i + 1];
result.push_back(prog.add_instruction(op, input_arg)); result.push_back(mm->add_instruction(op, input_arg));
} }
return result; return result;
} }
instruction_ref parse_squeeze(const std::string&, instruction_ref parse_squeeze(const std::string&,
const attribute_map& attributes, const attribute_map& attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args) const
{ {
op::squeeze op; op::squeeze op;
auto input_dims = args[0]->get_shape().lens(); auto input_dims = args[0]->get_shape().lens();
...@@ -943,7 +952,7 @@ struct tf_parser ...@@ -943,7 +952,7 @@ struct tf_parser
} }
} }
} }
return prog.add_instruction(op, make_contiguous(args[0])); return mm->add_instruction(op, make_contiguous(args[0]));
} }
instruction_ref parse_stridedslice(const std::string&, instruction_ref parse_stridedslice(const std::string&,
...@@ -991,7 +1000,7 @@ struct tf_parser ...@@ -991,7 +1000,7 @@ struct tf_parser
} }
} }
auto l1 = prog.add_instruction(op, l0); auto l1 = mm->add_instruction(op, l0);
if(shrink_axis_mask == 0) if(shrink_axis_mask == 0)
return l1; return l1;
...@@ -1002,17 +1011,18 @@ struct tf_parser ...@@ -1002,17 +1011,18 @@ struct tf_parser
squeeze_axes.push_back(i); squeeze_axes.push_back(i);
} }
return prog.add_instruction(op::squeeze{squeeze_axes}, l1); return mm->add_instruction(op::squeeze{squeeze_axes}, l1);
} }
instruction_ref instruction_ref parse_transpose(const std::string&,
parse_transpose(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const attribute_map&,
std::vector<instruction_ref> args) const
{ {
auto perm = args[1]->eval().get<int32_t>().to_vector(); auto perm = args[1]->eval().get<int32_t>().to_vector();
op::transpose op; op::transpose op;
op.dims = std::vector<int64_t>(perm.begin(), perm.end()); op.dims = std::vector<int64_t>(perm.begin(), perm.end());
return prog.add_instruction(op, args.front()); return mm->add_instruction(op, args.front());
} }
void parse_graph(const tensorflow::GraphDef& graph) void parse_graph(const tensorflow::GraphDef& graph)
...@@ -1032,7 +1042,7 @@ struct tf_parser ...@@ -1032,7 +1042,7 @@ struct tf_parser
return static_cast<int>(dim) <= 0 ? batch_size : dim; return static_cast<int>(dim) <= 0 ? batch_size : dim;
}); });
shape s = shape{shape_type, dims}; shape s = shape{shape_type, dims};
instructions[name] = to_nhwc(prog.add_parameter(name, s)); instructions[name] = to_nhwc(mm->add_parameter(name, s));
} }
for(auto&& p : nodes) for(auto&& p : nodes)
{ {
...@@ -1083,7 +1093,7 @@ struct tf_parser ...@@ -1083,7 +1093,7 @@ struct tf_parser
std::vector<instruction_ref> result; std::vector<instruction_ref> result;
if(ops.count(node.op()) == 0) if(ops.count(node.op()) == 0)
{ {
result.push_back(prog.add_instruction(op::unknown{node.op()}, args)); result.push_back(mm->add_instruction(op::unknown{node.op()}, args));
} }
else else
{ {
...@@ -1405,7 +1415,7 @@ program parse_tf(const std::string& name, tf_options options) ...@@ -1405,7 +1415,7 @@ program parse_tf(const std::string& name, tf_options options)
#else #else
parser.parse_from(input); parser.parse_from(input);
#endif #endif
parser.to_nchw(std::prev(parser.prog.end())); parser.to_nchw(std::prev(parser.mm->end()));
return std::move(parser.prog); return std::move(parser.prog);
} }
......
...@@ -75,26 +75,27 @@ struct test_stream_model ...@@ -75,26 +75,27 @@ struct test_stream_model
struct program_model struct program_model
{ {
migraphx::program p; migraphx::program p;
migraphx::module* mm = p.get_main_module();
std::unordered_map<migraphx::instruction_ref, std::size_t> ins2stream{}; std::unordered_map<migraphx::instruction_ref, std::size_t> ins2stream{};
std::size_t max_stream = 0; std::size_t max_stream = 0;
template <class... Ts> template <class... Ts>
migraphx::instruction_ref add_literal(Ts... xs) migraphx::instruction_ref add_literal(Ts... xs)
{ {
return p.add_literal(xs...); return mm->add_literal(xs...);
} }
template <class... Ts> template <class... Ts>
migraphx::instruction_ref add_instruction(Ts... xs) migraphx::instruction_ref add_instruction(Ts... xs)
{ {
return p.add_instruction(xs...); return mm->add_instruction(xs...);
} }
template <class... Ts> template <class... Ts>
migraphx::instruction_ref add_instruction_stream(std::size_t n, Ts... xs) migraphx::instruction_ref add_instruction_stream(std::size_t n, Ts... xs)
{ {
max_stream = std::max(max_stream, n); max_stream = std::max(max_stream, n);
auto ins = p.add_instruction(xs...); auto ins = mm->add_instruction(xs...);
ins2stream[ins] = n; ins2stream[ins] = n;
return ins; return ins;
} }
...@@ -102,14 +103,14 @@ struct program_model ...@@ -102,14 +103,14 @@ struct program_model
template <class... Ts> template <class... Ts>
migraphx::instruction_ref add_return(Ts... xs) migraphx::instruction_ref add_return(Ts... xs)
{ {
return p.add_return({xs...}); return mm->add_return({xs...});
} }
template <class... Ts> template <class... Ts>
migraphx::instruction_ref add_return_stream(std::size_t n, Ts... xs) migraphx::instruction_ref add_return_stream(std::size_t n, Ts... xs)
{ {
max_stream = std::max(max_stream, n); max_stream = std::max(max_stream, n);
auto ins = p.add_return({xs...}); auto ins = mm->add_return({xs...});
ins2stream[ins] = n; ins2stream[ins] = n;
return ins; return ins;
} }
...@@ -118,7 +119,7 @@ struct program_model ...@@ -118,7 +119,7 @@ struct program_model
std::vector<migraphx::stream_race> analyze() const std::vector<migraphx::stream_race> analyze() const
{ {
return migraphx::analyze_streams(p, get_stream_model()); return migraphx::analyze_streams(*p.get_main_module(), get_stream_model());
} }
void debug_print() const { p.debug_print(); } void debug_print() const { p.debug_print(); }
......
...@@ -6,13 +6,18 @@ ...@@ -6,13 +6,18 @@
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p) { migraphx::run_passes(p, {migraphx::auto_contiguous{}}); } void run_pass(migraphx::program& p)
{
migraphx::run_passes(*p.get_main_module(), {migraphx::auto_contiguous{}});
}
// TODO: Add this test case // TODO: Add this test case
void literal_broadcast() void literal_broadcast()
{ {
migraphx::program p; migraphx::program p;
p.add_literal(get_2_broadcasted());
auto* mm = p.get_main_module();
mm->add_literal(get_2_broadcasted());
EXPECT(not p.get_output_shapes().back().standard()); EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().broadcasted()); EXPECT(p.get_output_shapes().back().broadcasted());
run_pass(p); run_pass(p);
...@@ -23,7 +28,9 @@ void literal_broadcast() ...@@ -23,7 +28,9 @@ void literal_broadcast()
TEST_CASE(literal_transpose) TEST_CASE(literal_transpose)
{ {
migraphx::program p; migraphx::program p;
p.add_literal(get_2x2_transposed());
auto* mm = p.get_main_module();
mm->add_literal(get_2x2_transposed());
EXPECT(not p.get_output_shapes().back().standard()); EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed()); EXPECT(p.get_output_shapes().back().transposed());
run_pass(p); run_pass(p);
...@@ -34,11 +41,13 @@ TEST_CASE(literal_transpose) ...@@ -34,11 +41,13 @@ TEST_CASE(literal_transpose)
TEST_CASE(after_literal_transpose) TEST_CASE(after_literal_transpose)
{ {
migraphx::program p; migraphx::program p;
auto l = p.add_literal(get_2x2());
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
EXPECT(p.get_output_shapes().back().standard()); EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed()); EXPECT(not p.get_output_shapes().back().transposed());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
p.add_instruction(pass_op{}, t); mm->add_instruction(pass_op{}, t);
EXPECT(not p.get_output_shapes().back().standard()); EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed()); EXPECT(p.get_output_shapes().back().transposed());
run_pass(p); run_pass(p);
...@@ -49,12 +58,14 @@ TEST_CASE(after_literal_transpose) ...@@ -49,12 +58,14 @@ TEST_CASE(after_literal_transpose)
TEST_CASE(after_literal_broadcast) TEST_CASE(after_literal_broadcast)
{ {
migraphx::program p; migraphx::program p;
auto l1 = p.add_literal(get_2x2());
auto l2 = p.add_literal(get_2()); auto* mm = p.get_main_module();
auto l1 = mm->add_literal(get_2x2());
auto l2 = mm->add_literal(get_2());
EXPECT(p.get_output_shapes().back().standard()); EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().broadcasted()); EXPECT(not p.get_output_shapes().back().broadcasted());
auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2); auto b = mm->add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2);
p.add_instruction(pass_op{}, b); mm->add_instruction(pass_op{}, b);
EXPECT(not p.get_output_shapes().back().standard()); EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().broadcasted()); EXPECT(p.get_output_shapes().back().broadcasted());
run_pass(p); run_pass(p);
...@@ -65,11 +76,13 @@ TEST_CASE(after_literal_broadcast) ...@@ -65,11 +76,13 @@ TEST_CASE(after_literal_broadcast)
TEST_CASE(after_param_transpose) TEST_CASE(after_param_transpose)
{ {
migraphx::program p; migraphx::program p;
auto l = p.add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
auto* mm = p.get_main_module();
auto l = mm->add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
EXPECT(p.get_output_shapes().back().standard()); EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed()); EXPECT(not p.get_output_shapes().back().transposed());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
p.add_instruction(pass_op{}, t); mm->add_instruction(pass_op{}, t);
EXPECT(not p.get_output_shapes().back().standard()); EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed()); EXPECT(p.get_output_shapes().back().transposed());
run_pass(p); run_pass(p);
...@@ -80,12 +93,14 @@ TEST_CASE(after_param_transpose) ...@@ -80,12 +93,14 @@ TEST_CASE(after_param_transpose)
TEST_CASE(after_param_broadcast) TEST_CASE(after_param_broadcast)
{ {
migraphx::program p; migraphx::program p;
auto l1 = p.add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {2}}); auto* mm = p.get_main_module();
auto l1 = mm->add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {2}});
EXPECT(p.get_output_shapes().back().standard()); EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().broadcasted()); EXPECT(not p.get_output_shapes().back().broadcasted());
auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2); auto b = mm->add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2);
p.add_instruction(pass_op{}, b); mm->add_instruction(pass_op{}, b);
EXPECT(not p.get_output_shapes().back().standard()); EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().broadcasted()); EXPECT(p.get_output_shapes().back().broadcasted());
run_pass(p); run_pass(p);
......
...@@ -52,42 +52,52 @@ struct test_context ...@@ -52,42 +52,52 @@ struct test_context
TEST_CASE(literal_test) TEST_CASE(literal_test)
{ {
migraphx::program p; migraphx::program p;
auto lit = p.add_literal(1);
auto* mm = p.get_main_module();
auto lit = mm->add_literal(1);
CHECK(lit->eval() == migraphx::literal{1}); CHECK(lit->eval() == migraphx::literal{1});
} }
TEST_CASE(param_test) TEST_CASE(param_test)
{ {
migraphx::program p; migraphx::program p;
auto lit = p.add_parameter("param", migraphx::shape{migraphx::shape::float_type, {1}});
auto* mm = p.get_main_module();
auto lit = mm->add_parameter("param", migraphx::shape{migraphx::shape::float_type, {1}});
CHECK(lit->eval().empty()); CHECK(lit->eval().empty());
} }
TEST_CASE(op_test1) TEST_CASE(op_test1)
{ {
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2); auto* mm = p.get_main_module();
auto sum = p.add_instruction(sum_cf_op{}, one, two); auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_cf_op{}, one, two);
CHECK(sum->eval() == migraphx::literal{3}); CHECK(sum->eval() == migraphx::literal{3});
} }
TEST_CASE(op_test2) TEST_CASE(op_test2)
{ {
migraphx::program p; migraphx::program p;
auto x = p.add_parameter("param", migraphx::shape{migraphx::shape::float_type, {1}});
auto two = p.add_literal(2); auto* mm = p.get_main_module();
auto sum = p.add_instruction(sum_cf_op{}, x, two); auto x = mm->add_parameter("param", migraphx::shape{migraphx::shape::float_type, {1}});
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_cf_op{}, x, two);
CHECK(sum->eval().empty()); CHECK(sum->eval().empty());
} }
TEST_CASE(op_test3) TEST_CASE(op_test3)
{ {
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2); auto* mm = p.get_main_module();
auto sum1 = p.add_instruction(sum_op{}, one, two); auto one = mm->add_literal(1);
auto sum2 = p.add_instruction(sum_cf_op{}, sum1, two); auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_cf_op{}, sum1, two);
CHECK(sum2->eval().empty()); CHECK(sum2->eval().empty());
} }
......
...@@ -8,16 +8,16 @@ ...@@ -8,16 +8,16 @@
void run_pass(migraphx::program& p) void run_pass(migraphx::program& p)
{ {
migraphx::run_passes(p, {migraphx::dead_code_elimination{}}); migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}});
} }
TEST_CASE(simple_test) TEST_CASE(simple_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
p.add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(p.begin(), p.end()) == count);
...@@ -29,11 +29,11 @@ TEST_CASE(simple_test) ...@@ -29,11 +29,11 @@ TEST_CASE(simple_test)
TEST_CASE(simple_test_nop) TEST_CASE(simple_test_nop)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
p.add_instruction(nop{}); mm->add_instruction(nop{});
p.add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(p.begin(), p.end()) == count);
...@@ -45,12 +45,12 @@ TEST_CASE(simple_test_nop) ...@@ -45,12 +45,12 @@ TEST_CASE(simple_test_nop)
TEST_CASE(simple_test_nop2) TEST_CASE(simple_test_nop2)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
p.add_instruction(nop{}); mm->add_instruction(nop{});
p.add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
p.add_instruction(nop{}); mm->add_instruction(nop{});
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(p.begin(), p.end()) == 2);
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -61,11 +61,11 @@ TEST_CASE(simple_test_nop2) ...@@ -61,11 +61,11 @@ TEST_CASE(simple_test_nop2)
TEST_CASE(duplicate_test1) TEST_CASE(duplicate_test1)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
p.add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
p.add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == (count - 1)); EXPECT(std::distance(p.begin(), p.end()) == (count - 1));
...@@ -77,12 +77,12 @@ TEST_CASE(duplicate_test1) ...@@ -77,12 +77,12 @@ TEST_CASE(duplicate_test1)
TEST_CASE(duplicate_test2) TEST_CASE(duplicate_test2)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
p.add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
p.add_instruction(minus_op{}, one, two); mm->add_instruction(minus_op{}, one, two);
p.add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == (count - 2)); EXPECT(std::distance(p.begin(), p.end()) == (count - 2));
...@@ -94,14 +94,14 @@ TEST_CASE(duplicate_test2) ...@@ -94,14 +94,14 @@ TEST_CASE(duplicate_test2)
TEST_CASE(depth_test) TEST_CASE(depth_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
auto x1 = p.add_instruction(sum_op{}, one, two); auto x1 = mm->add_instruction(sum_op{}, one, two);
auto x2 = p.add_instruction(sum_op{}, one, two); auto x2 = mm->add_instruction(sum_op{}, one, two);
p.add_instruction(minus_op{}, x1, x2); mm->add_instruction(minus_op{}, x1, x2);
p.add_instruction(minus_op{}, x1, x2); mm->add_instruction(minus_op{}, x1, x2);
p.add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == (count - 4)); EXPECT(std::distance(p.begin(), p.end()) == (count - 4));
...@@ -113,15 +113,15 @@ TEST_CASE(depth_test) ...@@ -113,15 +113,15 @@ TEST_CASE(depth_test)
TEST_CASE(undefined_test) TEST_CASE(undefined_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
auto undef = p.add_instruction(migraphx::op::undefined{}); auto undef = mm->add_instruction(migraphx::op::undefined{});
p.add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count - 1); EXPECT(std::distance(p.begin(), p.end()) == count - 1);
EXPECT(not p.has_instruction(undef)); EXPECT(not mm->has_instruction(undef));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
...@@ -130,11 +130,11 @@ TEST_CASE(undefined_test) ...@@ -130,11 +130,11 @@ TEST_CASE(undefined_test)
TEST_CASE(duplicate_args1) TEST_CASE(duplicate_args1)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = p.add_literal(0); auto l0 = mm->add_literal(0);
auto l3 = p.add_literal(3); auto l3 = mm->add_literal(3);
p.add_instruction(migraphx::op::add{}, l3, l3); mm->add_instruction(migraphx::op::add{}, l3, l3);
p.add_instruction(migraphx::op::identity{}, l0); mm->add_instruction(migraphx::op::identity{}, l0);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) != count); EXPECT(std::distance(p.begin(), p.end()) != count);
...@@ -146,12 +146,12 @@ TEST_CASE(duplicate_args1) ...@@ -146,12 +146,12 @@ TEST_CASE(duplicate_args1)
TEST_CASE(duplicate_args2) TEST_CASE(duplicate_args2)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = p.add_literal(0); auto l0 = mm->add_literal(0);
auto l3 = p.add_literal(3); auto l3 = mm->add_literal(3);
auto sum1 = p.add_instruction(migraphx::op::add{}, l0, l3); auto sum1 = mm->add_instruction(migraphx::op::add{}, l0, l3);
p.add_instruction(migraphx::op::add{}, sum1, l3); mm->add_instruction(migraphx::op::add{}, sum1, l3);
p.add_instruction(migraphx::op::identity{}, l0); mm->add_instruction(migraphx::op::identity{}, l0);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) != count); EXPECT(std::distance(p.begin(), p.end()) != count);
...@@ -163,13 +163,13 @@ TEST_CASE(duplicate_args2) ...@@ -163,13 +163,13 @@ TEST_CASE(duplicate_args2)
TEST_CASE(duplicate_args3) TEST_CASE(duplicate_args3)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = p.add_literal(0); auto l0 = mm->add_literal(0);
auto l3 = p.add_literal(3); auto l3 = mm->add_literal(3);
auto sum1 = p.add_instruction(migraphx::op::add{}, l0, l3); auto sum1 = mm->add_instruction(migraphx::op::add{}, l0, l3);
auto sum2 = p.add_instruction(migraphx::op::add{}, l0, sum1); auto sum2 = mm->add_instruction(migraphx::op::add{}, l0, sum1);
p.add_instruction(migraphx::op::add{}, sum2, l3); mm->add_instruction(migraphx::op::add{}, sum2, l3);
p.add_instruction(migraphx::op::identity{}, l0); mm->add_instruction(migraphx::op::identity{}, l0);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) != count); EXPECT(std::distance(p.begin(), p.end()) != count);
......
...@@ -8,27 +8,32 @@ ...@@ -8,27 +8,32 @@
#include <migraphx/op/multibroadcast.hpp> #include <migraphx/op/multibroadcast.hpp>
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p) { migraphx::run_passes(p, {migraphx::decompose{}}); } void run_pass(migraphx::program& p)
{
migraphx::run_passes(*p.get_main_module(), {migraphx::decompose{}});
}
TEST_CASE(dot_add) TEST_CASE(dot_add)
{ {
migraphx::program p1; migraphx::program p1;
{ {
auto x = p1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto* mm1 = p1.get_main_module();
auto y = p1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = p1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = p1.add_instruction(migraphx::op::dot{}, x, y, z); auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
p1.add_instruction(migraphx::op::identity{}, dot); auto dot = mm1->add_instruction(migraphx::op::dot{}, x, y, z);
mm1->add_instruction(migraphx::op::identity{}, dot);
} }
run_pass(p1); run_pass(p1);
migraphx::program p2; migraphx::program p2;
{ {
auto x = p2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto* mm2 = p2.get_main_module();
auto y = p2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = p2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = p2.add_instruction(migraphx::op::dot{1, 0}, x, y); auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto add = p2.add_instruction(migraphx::op::add{}, dot, z); auto dot = mm2->add_instruction(migraphx::op::dot{1, 0}, x, y);
p2.add_instruction(migraphx::op::identity{}, add); auto add = mm2->add_instruction(migraphx::op::add{}, dot, z);
mm2->add_instruction(migraphx::op::identity{}, add);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
...@@ -37,25 +42,27 @@ TEST_CASE(dot_add_beta_float) ...@@ -37,25 +42,27 @@ TEST_CASE(dot_add_beta_float)
{ {
migraphx::program p1; migraphx::program p1;
{ {
auto x = p1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto* mm1 = p1.get_main_module();
auto y = p1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = p1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = p1.add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z); auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
p1.add_instruction(migraphx::op::identity{}, dot); auto dot = mm1->add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z);
mm1->add_instruction(migraphx::op::identity{}, dot);
} }
run_pass(p1); run_pass(p1);
migraphx::program p2; migraphx::program p2;
{ {
auto x = p2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto* mm2 = p2.get_main_module();
auto y = p2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = p2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = p2.add_instruction(migraphx::op::dot{1, 0}, x, y); auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto beta = auto dot = mm2->add_instruction(migraphx::op::dot{1, 0}, x, y);
p2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}}); auto beta = mm2->add_literal(
auto beta_broadcast = p2.add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta); migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}});
auto mul = p2.add_instruction(migraphx::op::mul{}, z, beta_broadcast); auto beta_broadcast = mm2->add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta);
auto add = p2.add_instruction(migraphx::op::add{}, dot, mul); auto mul = mm2->add_instruction(migraphx::op::mul{}, z, beta_broadcast);
p2.add_instruction(migraphx::op::identity{}, add); auto add = mm2->add_instruction(migraphx::op::add{}, dot, mul);
mm2->add_instruction(migraphx::op::identity{}, add);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
...@@ -64,25 +71,27 @@ TEST_CASE(dot_add_beta_half) ...@@ -64,25 +71,27 @@ TEST_CASE(dot_add_beta_half)
{ {
migraphx::program p1; migraphx::program p1;
{ {
auto x = p1.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto* mm1 = p1.get_main_module();
auto y = p1.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = p1.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot = p1.add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z); auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
p1.add_instruction(migraphx::op::identity{}, dot); auto dot = mm1->add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z);
mm1->add_instruction(migraphx::op::identity{}, dot);
} }
run_pass(p1); run_pass(p1);
migraphx::program p2; migraphx::program p2;
{ {
auto x = p2.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto* mm2 = p2.get_main_module();
auto y = p2.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = p2.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot = p2.add_instruction(migraphx::op::dot{1, 0}, x, y); auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot = mm2->add_instruction(migraphx::op::dot{1, 0}, x, y);
auto beta = auto beta =
p2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}}); mm2->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}});
auto beta_broadcast = p2.add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta); auto beta_broadcast = mm2->add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta);
auto mul = p2.add_instruction(migraphx::op::mul{}, z, beta_broadcast); auto mul = mm2->add_instruction(migraphx::op::mul{}, z, beta_broadcast);
auto add = p2.add_instruction(migraphx::op::add{}, dot, mul); auto add = mm2->add_instruction(migraphx::op::add{}, dot, mul);
p2.add_instruction(migraphx::op::identity{}, add); mm2->add_instruction(migraphx::op::identity{}, add);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
...@@ -91,25 +100,27 @@ TEST_CASE(dot_add_beta_double) ...@@ -91,25 +100,27 @@ TEST_CASE(dot_add_beta_double)
{ {
migraphx::program p1; migraphx::program p1;
{ {
auto x = p1.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto* mm1 = p1.get_main_module();
auto y = p1.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = p1.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto dot = p1.add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z); auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}});
p1.add_instruction(migraphx::op::identity{}, dot); auto dot = mm1->add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z);
mm1->add_instruction(migraphx::op::identity{}, dot);
} }
run_pass(p1); run_pass(p1);
migraphx::program p2; migraphx::program p2;
{ {
auto x = p2.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto* mm2 = p2.get_main_module();
auto y = p2.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = p2.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto dot = p2.add_instruction(migraphx::op::dot{1, 0}, x, y); auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto beta = auto dot = mm2->add_instruction(migraphx::op::dot{1, 0}, x, y);
p2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}}); auto beta = mm2->add_literal(
auto beta_broadcast = p2.add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta); migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}});
auto mul = p2.add_instruction(migraphx::op::mul{}, z, beta_broadcast); auto beta_broadcast = mm2->add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta);
auto add = p2.add_instruction(migraphx::op::add{}, dot, mul); auto mul = mm2->add_instruction(migraphx::op::mul{}, z, beta_broadcast);
p2.add_instruction(migraphx::op::identity{}, add); auto add = mm2->add_instruction(migraphx::op::add{}, dot, mul);
mm2->add_instruction(migraphx::op::identity{}, add);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
...@@ -118,11 +129,12 @@ TEST_CASE(dot_add_beta_int) ...@@ -118,11 +129,12 @@ TEST_CASE(dot_add_beta_int)
{ {
migraphx::program p1; migraphx::program p1;
{ {
auto x = p1.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}}); auto* mm1 = p1.get_main_module();
auto y = p1.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}}); auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto z = p1.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}}); auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto dot = p1.add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z); auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
p1.add_instruction(migraphx::op::identity{}, dot); auto dot = mm1->add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z);
mm1->add_instruction(migraphx::op::identity{}, dot);
} }
migraphx::program p2 = p1; migraphx::program p2 = p1;
run_pass(p1); run_pass(p1);
......
...@@ -9,7 +9,8 @@ ...@@ -9,7 +9,8 @@
void run_pass(migraphx::program& p, std::size_t align = 32) void run_pass(migraphx::program& p, std::size_t align = 32)
{ {
migraphx::run_passes( migraphx::run_passes(
p, {migraphx::eliminate_allocation{"allocate", align}, migraphx::dead_code_elimination{}}); *p.get_main_module(),
{migraphx::eliminate_allocation{"allocate", align}, migraphx::dead_code_elimination{}});
} }
struct allocate struct allocate
...@@ -39,14 +40,16 @@ struct allocate ...@@ -39,14 +40,16 @@ struct allocate
TEST_CASE(basic) TEST_CASE(basic)
{ {
migraphx::program p; migraphx::program p;
auto a1 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {8}}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {40}}}); auto* mm = p.get_main_module();
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto a1 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {8}}});
auto p1 = mm->add_instruction(pass_op{}, a1);
auto a3 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}}); auto a2 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {40}}});
p.add_instruction(pass_op{}, a3, p2); auto p2 = mm->add_instruction(pass_op{}, a2, p1);
auto a3 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
mm->add_instruction(pass_op{}, a3, p2);
run_pass(p); run_pass(p);
EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}}); EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}});
...@@ -56,14 +59,16 @@ TEST_CASE(basic) ...@@ -56,14 +59,16 @@ TEST_CASE(basic)
TEST_CASE(aligned) TEST_CASE(aligned)
{ {
migraphx::program p; migraphx::program p;
auto a1 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}}); auto* mm = p.get_main_module();
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto a1 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
auto p1 = mm->add_instruction(pass_op{}, a1);
auto a2 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}});
auto p2 = mm->add_instruction(pass_op{}, a2, p1);
auto a3 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}}); auto a3 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
p.add_instruction(pass_op{}, a3, p2); mm->add_instruction(pass_op{}, a3, p2);
run_pass(p); run_pass(p);
EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}}); EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}});
...@@ -73,14 +78,16 @@ TEST_CASE(aligned) ...@@ -73,14 +78,16 @@ TEST_CASE(aligned)
TEST_CASE(unaligned) TEST_CASE(unaligned)
{ {
migraphx::program p; migraphx::program p;
auto a1 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}}); auto* mm = p.get_main_module();
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto a1 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
auto p1 = mm->add_instruction(pass_op{}, a1);
auto a3 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}}); auto a2 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}});
p.add_instruction(pass_op{}, a3, p2); auto p2 = mm->add_instruction(pass_op{}, a2, p1);
auto a3 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
mm->add_instruction(pass_op{}, a3, p2);
run_pass(p, 1); run_pass(p, 1);
EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}}); EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}});
...@@ -90,14 +97,16 @@ TEST_CASE(unaligned) ...@@ -90,14 +97,16 @@ TEST_CASE(unaligned)
TEST_CASE(float_aligned) TEST_CASE(float_aligned)
{ {
migraphx::program p; migraphx::program p;
auto a1 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}}); auto* mm = p.get_main_module();
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto a1 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
auto p1 = mm->add_instruction(pass_op{}, a1);
auto a2 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}});
auto p2 = mm->add_instruction(pass_op{}, a2, p1);
auto a3 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}}); auto a3 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
p.add_instruction(pass_op{}, a3, p2); mm->add_instruction(pass_op{}, a3, p2);
run_pass(p, 4); run_pass(p, 4);
EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}}); EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}});
......
...@@ -8,29 +8,32 @@ ...@@ -8,29 +8,32 @@
void run_pass(migraphx::program& p) void run_pass(migraphx::program& p)
{ {
migraphx::run_passes( migraphx::run_passes(
p, {migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}}); *p.get_main_module(),
{migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}});
} }
TEST_CASE(cse_test1) TEST_CASE(cse_test1)
{ {
migraphx::program p1; migraphx::program p1;
{ {
auto one = p1.add_literal(1); auto* mm1 = p1.get_main_module();
auto two = p1.add_literal(2); auto one = mm1->add_literal(1);
auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two); auto two = mm1->add_literal(2);
auto sum2 = p1.add_instruction(migraphx::op::add{}, one, two); auto sum1 = mm1->add_instruction(migraphx::op::add{}, one, two);
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); auto sum2 = mm1->add_instruction(migraphx::op::add{}, one, two);
p1.add_instruction(pass_op{}, sum3); auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
mm1->add_instruction(pass_op{}, sum3);
} }
run_pass(p1); run_pass(p1);
migraphx::program p2; migraphx::program p2;
{ {
auto one = p2.add_literal(1); auto* mm2 = p2.get_main_module();
auto two = p2.add_literal(2); auto one = mm2->add_literal(1);
auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two); auto two = mm2->add_literal(2);
auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum1); auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, two);
p2.add_instruction(pass_op{}, sum3); auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum1, sum1);
mm2->add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
...@@ -39,23 +42,25 @@ TEST_CASE(cse_test2) ...@@ -39,23 +42,25 @@ TEST_CASE(cse_test2)
{ {
migraphx::program p1; migraphx::program p1;
{ {
auto one = p1.add_literal(1); auto* mm1 = p1.get_main_module();
auto two = p1.add_literal(2); auto one = mm1->add_literal(1);
auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two); auto two = mm1->add_literal(2);
auto sum2 = p1.add_instruction(migraphx::op::add{}, two, one); auto sum1 = mm1->add_instruction(migraphx::op::add{}, one, two);
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); auto sum2 = mm1->add_instruction(migraphx::op::add{}, two, one);
p1.add_instruction(pass_op{}, sum3); auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
mm1->add_instruction(pass_op{}, sum3);
} }
run_pass(p1); run_pass(p1);
migraphx::program p2; migraphx::program p2;
{ {
auto one = p2.add_literal(1); auto* mm2 = p2.get_main_module();
auto two = p2.add_literal(2); auto one = mm2->add_literal(1);
auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two); auto two = mm2->add_literal(2);
auto sum2 = p2.add_instruction(migraphx::op::add{}, two, one); auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, two);
auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum2); auto sum2 = mm2->add_instruction(migraphx::op::add{}, two, one);
p2.add_instruction(pass_op{}, sum3); auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum1, sum2);
mm2->add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
...@@ -64,21 +69,23 @@ TEST_CASE(cse_test3) ...@@ -64,21 +69,23 @@ TEST_CASE(cse_test3)
{ {
migraphx::program p1; migraphx::program p1;
{ {
auto one = p1.add_literal(1); auto* mm1 = p1.get_main_module();
auto two = p1.add_literal(1); auto one = mm1->add_literal(1);
auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two); auto two = mm1->add_literal(1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, two, one); auto sum1 = mm1->add_instruction(migraphx::op::add{}, one, two);
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); auto sum2 = mm1->add_instruction(migraphx::op::add{}, two, one);
p1.add_instruction(pass_op{}, sum3); auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
mm1->add_instruction(pass_op{}, sum3);
} }
run_pass(p1); run_pass(p1);
migraphx::program p2; migraphx::program p2;
{ {
auto one = p2.add_literal(1); auto* mm2 = p2.get_main_module();
auto sum1 = p2.add_instruction(migraphx::op::add{}, one, one); auto one = mm2->add_literal(1);
auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum1); auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, one);
p2.add_instruction(pass_op{}, sum3); auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum1, sum1);
mm2->add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
...@@ -87,24 +94,26 @@ TEST_CASE(cse_test4) ...@@ -87,24 +94,26 @@ TEST_CASE(cse_test4)
{ {
migraphx::program p1; migraphx::program p1;
{ {
auto one = p1.add_literal(1); auto* mm1 = p1.get_main_module();
auto two = p1.add_literal(1); auto one = mm1->add_literal(1);
auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two); auto two = mm1->add_literal(1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, two, one); auto sum1 = mm1->add_instruction(migraphx::op::add{}, one, two);
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, one); auto sum2 = mm1->add_instruction(migraphx::op::add{}, two, one);
auto sum4 = p1.add_instruction(migraphx::op::add{}, sum2, two); auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, one);
auto sum5 = p1.add_instruction(migraphx::op::add{}, sum4, sum3); auto sum4 = mm1->add_instruction(migraphx::op::add{}, sum2, two);
p1.add_instruction(pass_op{}, sum5); auto sum5 = mm1->add_instruction(migraphx::op::add{}, sum4, sum3);
mm1->add_instruction(pass_op{}, sum5);
} }
run_pass(p1); run_pass(p1);
migraphx::program p2; migraphx::program p2;
{ {
auto one = p2.add_literal(1); auto* mm2 = p2.get_main_module();
auto sum1 = p2.add_instruction(migraphx::op::add{}, one, one); auto one = mm2->add_literal(1);
auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, one); auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, one);
auto sum5 = p2.add_instruction(migraphx::op::add{}, sum3, sum3); auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum1, one);
p2.add_instruction(pass_op{}, sum5); auto sum5 = mm2->add_instruction(migraphx::op::add{}, sum3, sum3);
mm2->add_instruction(pass_op{}, sum5);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
...@@ -113,30 +122,32 @@ TEST_CASE(cse_test_literal) ...@@ -113,30 +122,32 @@ TEST_CASE(cse_test_literal)
{ {
migraphx::program p1; migraphx::program p1;
{ {
auto six1 = p1.add_literal(6); auto* mm1 = p1.get_main_module();
auto zero1 = p1.add_literal(0); auto six1 = mm1->add_literal(6);
auto six2 = p1.add_literal(6); auto zero1 = mm1->add_literal(0);
auto zero2 = p1.add_literal(0); auto six2 = mm1->add_literal(6);
auto six3 = p1.add_literal(6); auto zero2 = mm1->add_literal(0);
auto zero3 = p1.add_literal(0); auto six3 = mm1->add_literal(6);
auto zero3 = mm1->add_literal(0);
auto sum1 = p1.add_instruction(migraphx::op::add{}, six1, zero1); auto sum1 = mm1->add_instruction(migraphx::op::add{}, six1, zero1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, six2, zero2); auto sum2 = mm1->add_instruction(migraphx::op::add{}, six2, zero2);
auto sum3 = p1.add_instruction(migraphx::op::add{}, six3, zero3); auto sum3 = mm1->add_instruction(migraphx::op::add{}, six3, zero3);
auto sum4 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); auto sum4 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum5 = p1.add_instruction(migraphx::op::add{}, sum3, sum4); auto sum5 = mm1->add_instruction(migraphx::op::add{}, sum3, sum4);
p1.add_instruction(pass_op{}, sum5); mm1->add_instruction(pass_op{}, sum5);
} }
run_pass(p1); run_pass(p1);
migraphx::program p2; migraphx::program p2;
{ {
auto six = p2.add_literal(6); auto* mm2 = p2.get_main_module();
auto zero = p2.add_literal(0); auto six = mm2->add_literal(6);
auto sum1 = p2.add_instruction(migraphx::op::add{}, six, zero); auto zero = mm2->add_literal(0);
auto sum2 = p2.add_instruction(migraphx::op::add{}, sum1, sum1); auto sum1 = mm2->add_instruction(migraphx::op::add{}, six, zero);
auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum2); auto sum2 = mm2->add_instruction(migraphx::op::add{}, sum1, sum1);
p2.add_instruction(pass_op{}, sum3); auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum1, sum2);
mm2->add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
......
...@@ -47,7 +47,7 @@ struct concat_test_optimization ...@@ -47,7 +47,7 @@ struct concat_test_optimization
void run_pass(migraphx::program& p) void run_pass(migraphx::program& p)
{ {
migraphx::run_passes(p, migraphx::run_passes(*p.get_main_module(),
{migraphx::eliminate_concat{concat_test_optimization{}}, {migraphx::eliminate_concat{concat_test_optimization{}},
migraphx::dead_code_elimination{}}); migraphx::dead_code_elimination{}});
} }
...@@ -106,23 +106,27 @@ TEST_CASE(simple) ...@@ -106,23 +106,27 @@ TEST_CASE(simple)
{ {
auto create_test_program = [] { auto create_test_program = [] {
migraphx::program p; migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(1)});
auto p1 = p.add_instruction(simple_op{}, a1); auto* mm = p.get_main_module();
auto a2 = p.add_instruction(allocate{create_shape(1)}); auto a1 = mm->add_instruction(allocate{create_shape(1)});
auto p2 = p.add_instruction(simple_op{}, a2); auto p1 = mm->add_instruction(simple_op{}, a1);
auto a2 = mm->add_instruction(allocate{create_shape(1)});
auto p2 = mm->add_instruction(simple_op{}, a2);
std::size_t axis = 0; std::size_t axis = 0;
auto a3 = p.add_instruction(allocate{create_shape(2)}); auto a3 = mm->add_instruction(allocate{create_shape(2)});
p.add_instruction(concat(axis), p1, p2, a3); mm->add_instruction(concat(axis), p1, p2, a3);
return p; return p;
}; };
auto create_control_program = [] { auto create_control_program = [] {
migraphx::program p; migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(2)});
auto l1 = p.add_instruction(load{create_shape(1), 0}, a1); auto* mm = p.get_main_module();
auto p1 = p.add_instruction(simple_op{}, l1); auto a1 = mm->add_instruction(allocate{create_shape(2)});
auto l2 = p.add_instruction(load{create_shape(1), 4}, a1); auto l1 = mm->add_instruction(load{create_shape(1), 0}, a1);
auto p2 = p.add_instruction(simple_op{}, l2); auto p1 = mm->add_instruction(simple_op{}, l1);
p.add_instruction(identity{}, a1, p1, p2); auto l2 = mm->add_instruction(load{create_shape(1), 4}, a1);
auto p2 = mm->add_instruction(simple_op{}, l2);
mm->add_instruction(identity{}, a1, p1, p2);
return p; return p;
}; };
...@@ -137,13 +141,15 @@ TEST_CASE(negative_axis1) ...@@ -137,13 +141,15 @@ TEST_CASE(negative_axis1)
{ {
auto create_test_program = [] { auto create_test_program = [] {
migraphx::program p; migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(2, 2)});
auto p1 = p.add_instruction(simple_op{}, a1); auto* mm = p.get_main_module();
auto a2 = p.add_instruction(allocate{create_shape(2, 2)}); auto a1 = mm->add_instruction(allocate{create_shape(2, 2)});
auto p2 = p.add_instruction(simple_op{}, a2); auto p1 = mm->add_instruction(simple_op{}, a1);
auto a2 = mm->add_instruction(allocate{create_shape(2, 2)});
auto p2 = mm->add_instruction(simple_op{}, a2);
std::size_t axis = -1; std::size_t axis = -1;
auto a3 = p.add_instruction(allocate{create_shape(4, 2)}); auto a3 = mm->add_instruction(allocate{create_shape(4, 2)});
p.add_instruction(concat(axis), p1, p2, a3); mm->add_instruction(concat(axis), p1, p2, a3);
return p; return p;
}; };
auto create_control_program = create_test_program; auto create_control_program = create_test_program;
...@@ -159,23 +165,27 @@ TEST_CASE(negative_axis2) ...@@ -159,23 +165,27 @@ TEST_CASE(negative_axis2)
{ {
auto create_test_program = [] { auto create_test_program = [] {
migraphx::program p; migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(2, 2)});
auto p1 = p.add_instruction(simple_op{}, a1); auto* mm = p.get_main_module();
auto a2 = p.add_instruction(allocate{create_shape(2, 2)}); auto a1 = mm->add_instruction(allocate{create_shape(2, 2)});
auto p2 = p.add_instruction(simple_op{}, a2); auto p1 = mm->add_instruction(simple_op{}, a1);
auto a2 = mm->add_instruction(allocate{create_shape(2, 2)});
auto p2 = mm->add_instruction(simple_op{}, a2);
std::size_t axis = -2; std::size_t axis = -2;
auto a3 = p.add_instruction(allocate{create_shape(4, 2)}); auto a3 = mm->add_instruction(allocate{create_shape(4, 2)});
p.add_instruction(concat(axis), p1, p2, a3); mm->add_instruction(concat(axis), p1, p2, a3);
return p; return p;
}; };
auto create_control_program = [] { auto create_control_program = [] {
migraphx::program p; migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(4, 2)});
auto l1 = p.add_instruction(load{create_shape(2, 2), 0}, a1); auto* mm = p.get_main_module();
auto p1 = p.add_instruction(simple_op{}, l1); auto a1 = mm->add_instruction(allocate{create_shape(4, 2)});
auto l2 = p.add_instruction(load{create_shape(2, 2), 16}, a1); auto l1 = mm->add_instruction(load{create_shape(2, 2), 0}, a1);
auto p2 = p.add_instruction(simple_op{}, l2); auto p1 = mm->add_instruction(simple_op{}, l1);
p.add_instruction(identity{}, a1, p1, p2); auto l2 = mm->add_instruction(load{create_shape(2, 2), 16}, a1);
auto p2 = mm->add_instruction(simple_op{}, l2);
mm->add_instruction(identity{}, a1, p1, p2);
return p; return p;
}; };
...@@ -190,23 +200,27 @@ TEST_CASE(negative_axis3) ...@@ -190,23 +200,27 @@ TEST_CASE(negative_axis3)
{ {
auto create_test_program = [] { auto create_test_program = [] {
migraphx::program p; migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(1, 2, 2)});
auto p1 = p.add_instruction(simple_op{}, a1); auto* mm = p.get_main_module();
auto a2 = p.add_instruction(allocate{create_shape(1, 2, 2)}); auto a1 = mm->add_instruction(allocate{create_shape(1, 2, 2)});
auto p2 = p.add_instruction(simple_op{}, a2); auto p1 = mm->add_instruction(simple_op{}, a1);
auto a2 = mm->add_instruction(allocate{create_shape(1, 2, 2)});
auto p2 = mm->add_instruction(simple_op{}, a2);
std::size_t axis = -2; std::size_t axis = -2;
auto a3 = p.add_instruction(allocate{create_shape(1, 4, 2)}); auto a3 = mm->add_instruction(allocate{create_shape(1, 4, 2)});
p.add_instruction(concat(axis), p1, p2, a3); mm->add_instruction(concat(axis), p1, p2, a3);
return p; return p;
}; };
auto create_control_program = [] { auto create_control_program = [] {
migraphx::program p; migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(1, 4, 2)});
auto l1 = p.add_instruction(load{create_shape(1, 2, 2), 0}, a1); auto* mm = p.get_main_module();
auto p1 = p.add_instruction(simple_op{}, l1); auto a1 = mm->add_instruction(allocate{create_shape(1, 4, 2)});
auto l2 = p.add_instruction(load{create_shape(1, 2, 2), 16}, a1); auto l1 = mm->add_instruction(load{create_shape(1, 2, 2), 0}, a1);
auto p2 = p.add_instruction(simple_op{}, l2); auto p1 = mm->add_instruction(simple_op{}, l1);
p.add_instruction(identity{}, a1, p1, p2); auto l2 = mm->add_instruction(load{create_shape(1, 2, 2), 16}, a1);
auto p2 = mm->add_instruction(simple_op{}, l2);
mm->add_instruction(identity{}, a1, p1, p2);
return p; return p;
}; };
...@@ -221,23 +235,27 @@ TEST_CASE(reversed) ...@@ -221,23 +235,27 @@ TEST_CASE(reversed)
{ {
auto create_test_program = [] { auto create_test_program = [] {
migraphx::program p; migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(1)});
auto p1 = p.add_instruction(simple_op{}, a1); auto* mm = p.get_main_module();
auto a2 = p.add_instruction(allocate{create_shape(1)}); auto a1 = mm->add_instruction(allocate{create_shape(1)});
auto p2 = p.add_instruction(simple_op{}, a2); auto p1 = mm->add_instruction(simple_op{}, a1);
auto a2 = mm->add_instruction(allocate{create_shape(1)});
auto p2 = mm->add_instruction(simple_op{}, a2);
std::size_t axis = 0; std::size_t axis = 0;
auto a3 = p.add_instruction(allocate{create_shape(2)}); auto a3 = mm->add_instruction(allocate{create_shape(2)});
p.add_instruction(concat(axis), p2, p1, a3); mm->add_instruction(concat(axis), p2, p1, a3);
return p; return p;
}; };
auto create_control_program = [] { auto create_control_program = [] {
migraphx::program p; migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(2)});
auto l1 = p.add_instruction(load{create_shape(1), 4}, a1); auto* mm = p.get_main_module();
auto p1 = p.add_instruction(simple_op{}, l1); auto a1 = mm->add_instruction(allocate{create_shape(2)});
auto l2 = p.add_instruction(load{create_shape(1), 0}, a1); auto l1 = mm->add_instruction(load{create_shape(1), 4}, a1);
auto p2 = p.add_instruction(simple_op{}, l2); auto p1 = mm->add_instruction(simple_op{}, l1);
p.add_instruction(identity{}, a1, p2, p1); auto l2 = mm->add_instruction(load{create_shape(1), 0}, a1);
auto p2 = mm->add_instruction(simple_op{}, l2);
mm->add_instruction(identity{}, a1, p2, p1);
return p; return p;
}; };
...@@ -251,38 +269,42 @@ TEST_CASE(reversed) ...@@ -251,38 +269,42 @@ TEST_CASE(reversed)
TEST_CASE(nested) TEST_CASE(nested)
{ {
auto concat_test_program = [](auto& p) { auto concat_test_program = [](auto& p) {
auto a1 = p.add_instruction(allocate{create_shape(1)}); auto* mm = p.get_main_module();
auto p1 = p.add_instruction(simple_op{}, a1); auto a1 = mm->add_instruction(allocate{create_shape(1)});
auto a2 = p.add_instruction(allocate{create_shape(1)}); auto p1 = mm->add_instruction(simple_op{}, a1);
auto p2 = p.add_instruction(simple_op{}, a2); auto a2 = mm->add_instruction(allocate{create_shape(1)});
auto p2 = mm->add_instruction(simple_op{}, a2);
std::size_t axis = 0; std::size_t axis = 0;
auto a3 = p.add_instruction(allocate{create_shape(2)}); auto a3 = mm->add_instruction(allocate{create_shape(2)});
return p.add_instruction(concat(axis), p1, p2, a3); return mm->add_instruction(concat(axis), p1, p2, a3);
}; };
auto create_test_program = [&] { auto create_test_program = [&] {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto concat1 = concat_test_program(p); auto concat1 = concat_test_program(p);
auto concat2 = concat_test_program(p); auto concat2 = concat_test_program(p);
std::size_t axis = 0; std::size_t axis = 0;
auto a1 = p.add_instruction(allocate{create_shape(4)}); auto a1 = mm->add_instruction(allocate{create_shape(4)});
p.add_instruction(concat(axis), concat1, concat2, a1); mm->add_instruction(concat(axis), concat1, concat2, a1);
return p; return p;
}; };
auto concat_control_program = [](auto& p, auto a1) { auto concat_control_program = [](auto& p, auto a1) {
auto l1 = p.add_instruction(load{create_shape(1), 0}, a1); auto* mm = p.get_main_module();
auto p1 = p.add_instruction(simple_op{}, l1); auto l1 = mm->add_instruction(load{create_shape(1), 0}, a1);
auto l2 = p.add_instruction(load{create_shape(1), 4}, a1); auto p1 = mm->add_instruction(simple_op{}, l1);
auto p2 = p.add_instruction(simple_op{}, l2); auto l2 = mm->add_instruction(load{create_shape(1), 4}, a1);
return p.add_instruction(identity{}, a1, p1, p2); auto p2 = mm->add_instruction(simple_op{}, l2);
return mm->add_instruction(identity{}, a1, p1, p2);
}; };
auto create_control_program = [&] { auto create_control_program = [&] {
migraphx::program p; migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(4)}); auto* mm = p.get_main_module();
auto l1 = p.add_instruction(load{create_shape(2), 0}, a1); auto a1 = mm->add_instruction(allocate{create_shape(4)});
auto l1 = mm->add_instruction(load{create_shape(2), 0}, a1);
auto concat1 = concat_control_program(p, l1); auto concat1 = concat_control_program(p, l1);
auto l2 = p.add_instruction(load{create_shape(2), 8}, a1); auto l2 = mm->add_instruction(load{create_shape(2), 8}, a1);
auto concat2 = concat_control_program(p, l2); auto concat2 = concat_control_program(p, l2);
p.add_instruction(identity{}, a1, concat1, concat2); mm->add_instruction(identity{}, a1, concat1, concat2);
return p; return p;
}; };
...@@ -297,35 +319,37 @@ TEST_CASE(basic) ...@@ -297,35 +319,37 @@ TEST_CASE(basic)
{ {
auto create_test_program = [] { auto create_test_program = [] {
migraphx::program p; migraphx::program p;
auto a1 = auto* mm = p.get_main_module();
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}}); auto a1 = mm->add_instruction(
auto p1 = p.add_instruction(simple_op{}, a1); allocate{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}});
auto a2 = auto p1 = mm->add_instruction(simple_op{}, a1);
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}}); auto a2 = mm->add_instruction(
auto p2 = p.add_instruction(simple_op{}, a2); allocate{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}});
auto a3 = auto p2 = mm->add_instruction(simple_op{}, a2);
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}}); auto a3 = mm->add_instruction(
auto p3 = p.add_instruction(simple_op{}, a3); allocate{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}});
auto p3 = mm->add_instruction(simple_op{}, a3);
std::size_t axis = 1; std::size_t axis = 1;
auto a4 = p.add_instruction( auto a4 = mm->add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}}); allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}});
p.add_instruction(concat(axis), p1, p2, p3, a4); mm->add_instruction(concat(axis), p1, p2, p3, a4);
return p; return p;
}; };
auto create_control_program = [] { auto create_control_program = [] {
migraphx::program p; migraphx::program p;
auto a1 = p.add_instruction( auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}}); allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}});
auto l1 = p.add_instruction( auto l1 = mm->add_instruction(
load{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}, 0}, {a1}); load{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}, 0}, {a1});
auto p1 = p.add_instruction(simple_op{}, l1); auto p1 = mm->add_instruction(simple_op{}, l1);
auto l2 = p.add_instruction( auto l2 = mm->add_instruction(
load{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}, 512}, {a1}); load{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}, 512}, {a1});
auto p2 = p.add_instruction(simple_op{}, l2); auto p2 = mm->add_instruction(simple_op{}, l2);
auto l3 = p.add_instruction( auto l3 = mm->add_instruction(
load{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}, 1280}, {a1}); load{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}, 1280}, {a1});
auto p3 = p.add_instruction(simple_op{}, l3); auto p3 = mm->add_instruction(simple_op{}, l3);
p.add_instruction(identity{}, {a1, p1, p2, p3}); mm->add_instruction(identity{}, {a1, p1, p2, p3});
return p; return p;
}; };
...@@ -340,36 +364,38 @@ TEST_CASE(wont_work) ...@@ -340,36 +364,38 @@ TEST_CASE(wont_work)
{ {
auto create_test_program = [] { auto create_test_program = [] {
migraphx::program p; migraphx::program p;
auto a1 = auto* mm = p.get_main_module();
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}}); auto a1 = mm->add_instruction(
auto p1 = p.add_instruction(simple_op{}, a1); allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
auto a2 = auto p1 = mm->add_instruction(simple_op{}, a1);
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}}); auto a2 = mm->add_instruction(
auto p2 = p.add_instruction(simple_op{}, a2); allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
auto a3 = auto p2 = mm->add_instruction(simple_op{}, a2);
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}}); auto a3 = mm->add_instruction(
auto p3 = p.add_instruction(simple_op{}, a3); allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
auto p3 = mm->add_instruction(simple_op{}, a3);
std::size_t axis = 1; std::size_t axis = 1;
auto a4 = p.add_instruction( auto a4 = mm->add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}}); allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
p.add_instruction(concat(axis), p1, p2, p3, a4); mm->add_instruction(concat(axis), p1, p2, p3, a4);
return p; return p;
}; };
auto create_control_program = [] { auto create_control_program = [] {
migraphx::program p; migraphx::program p;
auto a1 = auto* mm = p.get_main_module();
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}}); auto a1 = mm->add_instruction(
auto p1 = p.add_instruction(simple_op{}, a1); allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
auto a2 = auto p1 = mm->add_instruction(simple_op{}, a1);
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}}); auto a2 = mm->add_instruction(
auto p2 = p.add_instruction(simple_op{}, a2); allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
auto a3 = auto p2 = mm->add_instruction(simple_op{}, a2);
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}}); auto a3 = mm->add_instruction(
auto p3 = p.add_instruction(simple_op{}, a3); allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
auto p3 = mm->add_instruction(simple_op{}, a3);
std::size_t axis = 1; std::size_t axis = 1;
auto a4 = p.add_instruction( auto a4 = mm->add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}}); allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
p.add_instruction(concat(axis), p1, p2, p3, a4); mm->add_instruction(concat(axis), p1, p2, p3, a4);
return p; return p;
}; };
......
...@@ -12,16 +12,19 @@ ...@@ -12,16 +12,19 @@
void run_pass(migraphx::program& p) void run_pass(migraphx::program& p)
{ {
migraphx::run_passes(p, {migraphx::eliminate_contiguous{}, migraphx::dead_code_elimination{}}); migraphx::run_passes(*p.get_main_module(),
{migraphx::eliminate_contiguous{}, migraphx::dead_code_elimination{}});
} }
TEST_CASE(standard_op) TEST_CASE(standard_op)
{ {
migraphx::program p; migraphx::program p;
auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto* mm = p.get_main_module();
auto c = p.add_instruction(migraphx::op::contiguous{}, t); auto l = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2}});
p.add_instruction(pass_standard_op{}, c); auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, t);
mm->add_instruction(pass_standard_op{}, c);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(p.begin(), p.end()) == count);
...@@ -30,10 +33,12 @@ TEST_CASE(standard_op) ...@@ -30,10 +33,12 @@ TEST_CASE(standard_op)
TEST_CASE(standard_op_const) TEST_CASE(standard_op_const)
{ {
migraphx::program p; migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto* mm = p.get_main_module();
auto c = p.add_instruction(migraphx::op::contiguous{}, t); auto l = mm->add_literal(get_2x2());
p.add_instruction(pass_standard_op{}, c); auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, t);
mm->add_instruction(pass_standard_op{}, c);
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(p.begin(), p.end()) == 2);
} }
...@@ -41,10 +46,12 @@ TEST_CASE(standard_op_const) ...@@ -41,10 +46,12 @@ TEST_CASE(standard_op_const)
TEST_CASE(non_standard_op) TEST_CASE(non_standard_op)
{ {
migraphx::program p; migraphx::program p;
auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto* mm = p.get_main_module();
auto c = p.add_instruction(migraphx::op::contiguous{}, t); auto l = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2}});
p.add_instruction(pass_op{}, c); auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, t);
mm->add_instruction(pass_op{}, c);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(p.begin(), p.end()) == count);
...@@ -53,10 +60,12 @@ TEST_CASE(non_standard_op) ...@@ -53,10 +60,12 @@ TEST_CASE(non_standard_op)
TEST_CASE(non_standard_op_const) TEST_CASE(non_standard_op_const)
{ {
migraphx::program p; migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto* mm = p.get_main_module();
auto c = p.add_instruction(migraphx::op::contiguous{}, t); auto l = mm->add_literal(get_2x2());
p.add_instruction(pass_op{}, c); auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, t);
mm->add_instruction(pass_op{}, c);
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(p.begin(), p.end()) == 2);
} }
...@@ -64,11 +73,13 @@ TEST_CASE(non_standard_op_const) ...@@ -64,11 +73,13 @@ TEST_CASE(non_standard_op_const)
TEST_CASE(transpose_gemm) TEST_CASE(transpose_gemm)
{ {
migraphx::program p; migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto* mm = p.get_main_module();
auto c = p.add_instruction(migraphx::op::contiguous{}, t); auto l = mm->add_literal(get_2x2());
auto ic = p.add_instruction(migraphx::op::identity{}, c); auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
p.add_instruction(migraphx::op::dot{}, ic, l); auto c = mm->add_instruction(migraphx::op::contiguous{}, t);
auto ic = mm->add_instruction(migraphx::op::identity{}, c);
mm->add_instruction(migraphx::op::dot{}, ic, l);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == (count - 1)); EXPECT(std::distance(p.begin(), p.end()) == (count - 1));
...@@ -77,11 +88,13 @@ TEST_CASE(transpose_gemm) ...@@ -77,11 +88,13 @@ TEST_CASE(transpose_gemm)
TEST_CASE(transpose_standard_op) TEST_CASE(transpose_standard_op)
{ {
migraphx::program p; migraphx::program p;
auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto* mm = p.get_main_module();
auto c = p.add_instruction(migraphx::op::contiguous{}, t); auto l = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto sn = p.add_instruction(migraphx::op::sin{}, c); auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
p.add_instruction(pass_standard_op{}, sn); auto c = mm->add_instruction(migraphx::op::contiguous{}, t);
auto sn = mm->add_instruction(migraphx::op::sin{}, c);
mm->add_instruction(pass_standard_op{}, sn);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(p.begin(), p.end()) == count);
...@@ -90,11 +103,13 @@ TEST_CASE(transpose_standard_op) ...@@ -90,11 +103,13 @@ TEST_CASE(transpose_standard_op)
TEST_CASE(transpose_standard_op_const) TEST_CASE(transpose_standard_op_const)
{ {
migraphx::program p; migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto* mm = p.get_main_module();
auto c = p.add_instruction(migraphx::op::contiguous{}, t); auto l = mm->add_literal(get_2x2());
auto sn = p.add_instruction(migraphx::op::sin{}, c); auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
p.add_instruction(pass_standard_op{}, sn); auto c = mm->add_instruction(migraphx::op::contiguous{}, t);
auto sn = mm->add_instruction(migraphx::op::sin{}, c);
mm->add_instruction(pass_standard_op{}, sn);
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == 3); EXPECT(std::distance(p.begin(), p.end()) == 3);
} }
...@@ -102,11 +117,13 @@ TEST_CASE(transpose_standard_op_const) ...@@ -102,11 +117,13 @@ TEST_CASE(transpose_standard_op_const)
TEST_CASE(no_packed_unary_op) TEST_CASE(no_packed_unary_op)
{ {
migraphx::program p; migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, l); auto* mm = p.get_main_module();
auto c = p.add_instruction(migraphx::op::contiguous{}, t); auto l = mm->add_literal(get_2x2());
auto sn = p.add_instruction(migraphx::op::sin{}, c); auto t = mm->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, l);
p.add_instruction(pass_standard_op{}, sn); auto c = mm->add_instruction(migraphx::op::contiguous{}, t);
auto sn = mm->add_instruction(migraphx::op::sin{}, c);
mm->add_instruction(pass_standard_op{}, sn);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count - 1); EXPECT(std::distance(p.begin(), p.end()) == count - 1);
...@@ -115,10 +132,12 @@ TEST_CASE(no_packed_unary_op) ...@@ -115,10 +132,12 @@ TEST_CASE(no_packed_unary_op)
TEST_CASE(non_standard_return_input) TEST_CASE(non_standard_return_input)
{ {
migraphx::program p; migraphx::program p;
auto l = p.add_literal(get_2x2());
auto tl = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto* mm = p.get_main_module();
auto c = p.add_instruction(migraphx::op::contiguous{}, tl); auto l = mm->add_literal(get_2x2());
p.add_return({c}); auto tl = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, tl);
mm->add_return({c});
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(p.begin(), p.end()) == count);
......
...@@ -6,17 +6,22 @@ ...@@ -6,17 +6,22 @@
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p) { migraphx::run_passes(p, {migraphx::eliminate_identity{}}); } void run_pass(migraphx::program& p)
{
migraphx::run_passes(*p.get_main_module(), {migraphx::eliminate_identity{}});
}
TEST_CASE(simple_test) TEST_CASE(simple_test)
{ {
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto* mm = p.get_main_module();
auto one_identity = p.add_instruction(migraphx::op::identity{}, one);
auto two = p.add_literal(2); auto one = mm->add_literal(1);
auto two_identity = p.add_instruction(migraphx::op::identity{}, two); auto one_identity = mm->add_instruction(migraphx::op::identity{}, one);
p.add_instruction(sum_op{}, one_identity, two_identity); auto two = mm->add_literal(2);
auto two_identity = mm->add_instruction(migraphx::op::identity{}, two);
mm->add_instruction(sum_op{}, one_identity, two_identity);
run_pass(p); run_pass(p);
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) { EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity"; return ins.name() == "identity";
...@@ -29,10 +34,12 @@ TEST_CASE(simple_test_end) ...@@ -29,10 +34,12 @@ TEST_CASE(simple_test_end)
{ {
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto* mm = p.get_main_module();
auto two = p.add_literal(2);
auto ans = p.add_instruction(sum_op{}, one, two); auto one = mm->add_literal(1);
p.add_instruction(migraphx::op::identity{}, ans); auto two = mm->add_literal(2);
auto ans = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::op::identity{}, ans);
run_pass(p); run_pass(p);
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) { EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity"; return ins.name() == "identity";
...@@ -45,12 +52,14 @@ TEST_CASE(simple_test_end_dependency) ...@@ -45,12 +52,14 @@ TEST_CASE(simple_test_end_dependency)
{ {
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1.0); auto* mm = p.get_main_module();
auto two = p.add_literal(2.0);
auto three = p.add_literal(3.0); auto one = mm->add_literal(1.0);
auto ans = p.add_instruction(sum_op{}, one, two); auto two = mm->add_literal(2.0);
p.add_instruction(sum_op{}, ans, three); auto three = mm->add_literal(3.0);
p.add_instruction(migraphx::op::identity{}, ans); auto ans = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(sum_op{}, ans, three);
mm->add_instruction(migraphx::op::identity{}, ans);
run_pass(p); run_pass(p);
EXPECT(std::any_of(p.begin(), p.end(), [](const migraphx::instruction& ins) { EXPECT(std::any_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity"; return ins.name() == "identity";
......
...@@ -8,7 +8,8 @@ ...@@ -8,7 +8,8 @@
void run_pass(migraphx::program& p) void run_pass(migraphx::program& p)
{ {
migraphx::run_passes(p, {migraphx::eliminate_pad{}, migraphx::dead_code_elimination{}}); migraphx::run_passes(*p.get_main_module(),
{migraphx::eliminate_pad{}, migraphx::dead_code_elimination{}});
} }
migraphx::instruction_ref migraphx::instruction_ref
...@@ -16,10 +17,10 @@ create_im2col(migraphx::instruction_ref& l_img, size_t channels, migraphx::progr ...@@ -16,10 +17,10 @@ create_im2col(migraphx::instruction_ref& l_img, size_t channels, migraphx::progr
{ {
size_t f[2] = {1, 1}; size_t f[2] = {1, 1};
std::vector<int32_t> weights(channels * f[0] * f[1]); std::vector<int32_t> weights(channels * f[0] * f[1]);
auto* mm = p.get_main_module();
migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}}; migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_weights = p.add_literal(migraphx::literal{s_weights, weights}); auto l_weights = mm->add_literal(migraphx::literal{s_weights, weights});
return p.add_instruction(migraphx::op::im2col{}, l_img, l_weights); return mm->add_instruction(migraphx::op::im2col{}, l_img, l_weights);
} }
migraphx::instruction_ref migraphx::instruction_ref
...@@ -30,30 +31,30 @@ create_conv(migraphx::instruction_ref& l_img, ...@@ -30,30 +31,30 @@ create_conv(migraphx::instruction_ref& l_img,
{ {
migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}}; migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}};
std::vector<int32_t> weights(4 * channels * 3 * 3); std::vector<int32_t> weights(4 * channels * 3 * 3);
auto* mm = p.get_main_module();
auto l_weights = p.add_literal(migraphx::literal{s_weights, weights}); auto l_weights = mm->add_literal(migraphx::literal{s_weights, weights});
migraphx::op::convolution op; migraphx::op::convolution op;
op.padding_mode = padding_mode; op.padding_mode = padding_mode;
return p.add_instruction(op, l_img, l_weights); return mm->add_instruction(op, l_img, l_weights);
} }
TEST_CASE(rewrite_test) TEST_CASE(rewrite_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
size_t img_dim[2] = {2, 2}; size_t img_dim[2] = {2, 2};
size_t channels = 1; size_t channels = 1;
std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]); std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]);
std::iota(input.begin(), input.end(), 0); std::iota(input.begin(), input.end(), 0);
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}}; migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = p.add_literal(migraphx::literal{s_img, input}); auto l_img = mm->add_literal(migraphx::literal{s_img, input});
auto padded_img = p.add_instruction(migraphx::op::pad{{0, 0, 1, 1, 0, 0, 1, 1}}, l_img); auto padded_img = mm->add_instruction(migraphx::op::pad{{0, 0, 1, 1, 0, 0, 1, 1}}, l_img);
auto l0 = create_im2col(padded_img, channels, p); auto l0 = create_im2col(padded_img, channels, p);
auto l1 = create_conv(padded_img, channels, p); auto l1 = create_conv(padded_img, channels, p);
auto l2 = p.add_instruction(migraphx::op::pooling{"max"}, padded_img); auto l2 = mm->add_instruction(migraphx::op::pooling{"max"}, padded_img);
p.add_instruction(migraphx::op::identity{}, l0, l1, l2); mm->add_instruction(migraphx::op::identity{}, l0, l1, l2);
run_pass(p); run_pass(p);
EXPECT(std::none_of( EXPECT(std::none_of(
...@@ -63,14 +64,16 @@ TEST_CASE(rewrite_test) ...@@ -63,14 +64,16 @@ TEST_CASE(rewrite_test)
TEST_CASE(rewrite_test_asymmetric) TEST_CASE(rewrite_test_asymmetric)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
size_t img_dim[2] = {2, 2}; size_t img_dim[2] = {2, 2};
size_t channels = 1; size_t channels = 1;
std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]); std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]);
std::iota(input.begin(), input.end(), 0); std::iota(input.begin(), input.end(), 0);
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}}; migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = p.add_literal(migraphx::literal{s_img, input}); auto l_img = mm->add_literal(migraphx::literal{s_img, input});
auto padded_img = p.add_instruction(migraphx::op::pad{{0, 0, 0, 0, 0, 0, 2, 2}}, l_img); auto padded_img = mm->add_instruction(migraphx::op::pad{{0, 0, 0, 0, 0, 0, 2, 2}}, l_img);
create_im2col(padded_img, channels, p); create_im2col(padded_img, channels, p);
......
...@@ -71,7 +71,7 @@ struct reverse_pass ...@@ -71,7 +71,7 @@ struct reverse_pass
{ {
std::string name() const { return "reverse_pass"; } std::string name() const { return "reverse_pass"; }
void apply(migraphx::program& p) const { std::reverse(p.begin(), p.end()); } void apply(migraphx::module& p) const { std::reverse(p.begin(), p.end()); }
}; };
struct reverse_target struct reverse_target
...@@ -89,7 +89,7 @@ struct invert_pass ...@@ -89,7 +89,7 @@ struct invert_pass
{ {
std::string name() const { return "invert_pass"; } std::string name() const { return "invert_pass"; }
void apply(migraphx::program& p) const void apply(migraphx::module& p) const
{ {
for(auto ins : migraphx::iterator_for(p)) for(auto ins : migraphx::iterator_for(p))
{ {
...@@ -130,10 +130,10 @@ struct double_invert_target ...@@ -130,10 +130,10 @@ struct double_invert_target
TEST_CASE(literal_test1) TEST_CASE(literal_test1)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
p.add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
...@@ -142,11 +142,11 @@ TEST_CASE(literal_test1) ...@@ -142,11 +142,11 @@ TEST_CASE(literal_test1)
TEST_CASE(literal_test2) TEST_CASE(literal_test2)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two); auto sum1 = mm->add_instruction(sum_op{}, one, two);
p.add_instruction(sum_op{}, sum1, two); mm->add_instruction(sum_op{}, sum1, two);
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{5}); EXPECT(result == migraphx::literal{5});
...@@ -156,10 +156,10 @@ TEST_CASE(literal_test2) ...@@ -156,10 +156,10 @@ TEST_CASE(literal_test2)
TEST_CASE(print_test) TEST_CASE(print_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto x = p.add_parameter("x", {migraphx::shape::int32_type}); auto x = mm->add_parameter("x", {migraphx::shape::int32_type});
auto two = p.add_literal(2); auto two = mm->add_literal(2);
p.add_instruction(sum_op{}, x, two); mm->add_instruction(sum_op{}, x, two);
std::stringstream ss; std::stringstream ss;
ss << p; ss << p;
...@@ -170,11 +170,11 @@ TEST_CASE(print_test) ...@@ -170,11 +170,11 @@ TEST_CASE(print_test)
TEST_CASE(param_test) TEST_CASE(param_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::int32_type});
auto y = mm->add_parameter("y", {migraphx::shape::int32_type});
auto x = p.add_parameter("x", {migraphx::shape::int32_type}); mm->add_instruction(sum_op{}, x, y);
auto y = p.add_parameter("y", {migraphx::shape::int32_type});
p.add_instruction(sum_op{}, x, y);
auto result = p.eval({{"x", migraphx::literal{1}.get_argument()}, auto result = p.eval({{"x", migraphx::literal{1}.get_argument()},
{"y", migraphx::literal{2}.get_argument()}}) {"y", migraphx::literal{2}.get_argument()}})
.back(); .back();
...@@ -185,11 +185,11 @@ TEST_CASE(param_test) ...@@ -185,11 +185,11 @@ TEST_CASE(param_test)
TEST_CASE(param_error_test) TEST_CASE(param_error_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::int32_type});
auto y = mm->add_parameter("y", {migraphx::shape::int32_type});
auto x = p.add_parameter("x", {migraphx::shape::int32_type}); mm->add_instruction(sum_op{}, x, y);
auto y = p.add_parameter("y", {migraphx::shape::int32_type});
p.add_instruction(sum_op{}, x, y);
EXPECT(test::throws<migraphx::exception>( EXPECT(test::throws<migraphx::exception>(
[&] { [&] {
p.eval({{"x", migraphx::literal{1}.get_argument()}}); p.eval({{"x", migraphx::literal{1}.get_argument()}});
...@@ -200,11 +200,11 @@ TEST_CASE(param_error_test) ...@@ -200,11 +200,11 @@ TEST_CASE(param_error_test)
TEST_CASE(param_error_shape_test) TEST_CASE(param_error_shape_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1, 1}});
auto y = mm->add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 1}}); mm->add_instruction(sum_op{}, x, y);
auto y = p.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
p.add_instruction(sum_op{}, x, y);
EXPECT(test::throws<migraphx::exception>( EXPECT(test::throws<migraphx::exception>(
[&] { [&] {
p.eval({ p.eval({
...@@ -218,10 +218,11 @@ TEST_CASE(param_error_shape_test) ...@@ -218,10 +218,11 @@ TEST_CASE(param_error_shape_test)
TEST_CASE(get_param1) TEST_CASE(get_param1)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s{migraphx::shape::int32_type, {1, 2}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = p.add_parameter("y", s); auto y = mm->add_parameter("y", s);
p.add_instruction(sum_op{}, x, y); mm->add_instruction(sum_op{}, x, y);
EXPECT(bool{p.get_parameter("x") == x}); EXPECT(bool{p.get_parameter("x") == x});
EXPECT(bool{p.get_parameter("y") == y}); EXPECT(bool{p.get_parameter("y") == y});
EXPECT(bool{p.get_parameter("nonexistent") == p.end()}); EXPECT(bool{p.get_parameter("nonexistent") == p.end()});
...@@ -230,19 +231,21 @@ TEST_CASE(get_param1) ...@@ -230,19 +231,21 @@ TEST_CASE(get_param1)
TEST_CASE(get_param2) TEST_CASE(get_param2)
{ {
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto* mm = p.get_main_module();
auto two = p.add_literal(2); auto one = mm->add_literal(1);
p.add_instruction(sum_op{}, one, two); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
EXPECT(bool{p.get_parameter("nonexistent") == p.end()}); EXPECT(bool{p.get_parameter("nonexistent") == p.end()});
} }
TEST_CASE(get_param_shapes) TEST_CASE(get_param_shapes)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s{migraphx::shape::int32_type, {1, 2}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = p.add_parameter("y", s); auto y = mm->add_parameter("y", s);
p.add_instruction(sum_op{}, x, y); mm->add_instruction(sum_op{}, x, y);
auto m = p.get_parameter_shapes(); auto m = p.get_parameter_shapes();
EXPECT(m.count("nonexistent") == 0); EXPECT(m.count("nonexistent") == 0);
EXPECT(m.at("x") == s); EXPECT(m.at("x") == s);
...@@ -252,11 +255,11 @@ TEST_CASE(get_param_shapes) ...@@ -252,11 +255,11 @@ TEST_CASE(get_param_shapes)
TEST_CASE(replace_test) TEST_CASE(replace_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(sum_op{}, one, two);
p.replace_instruction(sum, minus_op{}, two, one); mm->replace_instruction(sum, minus_op{}, two, one);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -267,12 +270,12 @@ TEST_CASE(replace_test) ...@@ -267,12 +270,12 @@ TEST_CASE(replace_test)
TEST_CASE(replace_ins_test) TEST_CASE(replace_ins_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(sum_op{}, one, two);
auto minus = p.add_instruction(minus_op{}, two, one); auto minus = mm->add_instruction(minus_op{}, two, one);
p.replace_instruction(sum, minus); mm->replace_instruction(sum, minus);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -283,13 +286,13 @@ TEST_CASE(replace_ins_test) ...@@ -283,13 +286,13 @@ TEST_CASE(replace_ins_test)
TEST_CASE(replace_ins_test2) TEST_CASE(replace_ins_test2)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(sum_op{}, one, two);
auto minus = p.add_instruction(minus_op{}, two, one); auto minus = mm->add_instruction(minus_op{}, two, one);
p.add_instruction(pass_op{}, minus); mm->add_instruction(pass_op{}, minus);
p.replace_instruction(two, sum); mm->replace_instruction(two, sum);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -300,10 +303,10 @@ TEST_CASE(replace_ins_test2) ...@@ -300,10 +303,10 @@ TEST_CASE(replace_ins_test2)
TEST_CASE(replace_op_test) TEST_CASE(replace_op_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
auto sum = p.add_instruction(sum_op{}, two, one); auto sum = mm->add_instruction(sum_op{}, two, one);
sum->replace(minus_op{}); sum->replace(minus_op{});
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
...@@ -315,24 +318,24 @@ TEST_CASE(replace_op_test) ...@@ -315,24 +318,24 @@ TEST_CASE(replace_op_test)
TEST_CASE(replace_op_recompute_shape_throw) TEST_CASE(replace_op_recompute_shape_throw)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(sum_op{}, one, two);
EXPECT(test::throws<migraphx::exception>([&] { sum->replace(unary_pass_op{}); })); EXPECT(test::throws<migraphx::exception>([&] { sum->replace(unary_pass_op{}); }));
} }
TEST_CASE(insert_replace_test) TEST_CASE(insert_replace_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two); auto sum1 = mm->add_instruction(sum_op{}, one, two);
p.add_instruction(sum_op{}, sum1, two); mm->add_instruction(sum_op{}, sum1, two);
auto sum0 = p.insert_instruction(sum1, sum_op{}, two, two); auto sum0 = mm->insert_instruction(sum1, sum_op{}, two, two);
p.replace_instruction(sum1, minus_op{}, sum0, two); mm->replace_instruction(sum1, minus_op{}, sum0, two);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -343,12 +346,12 @@ TEST_CASE(insert_replace_test) ...@@ -343,12 +346,12 @@ TEST_CASE(insert_replace_test)
TEST_CASE(remove_test1) TEST_CASE(remove_test1)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(sum_op{}, one, two);
auto removed = p.add_instruction(minus_op{}, sum, one); auto removed = mm->add_instruction(minus_op{}, sum, one);
p.remove_instruction(removed); mm->remove_instruction(removed);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -359,12 +362,12 @@ TEST_CASE(remove_test1) ...@@ -359,12 +362,12 @@ TEST_CASE(remove_test1)
TEST_CASE(remove_test2) TEST_CASE(remove_test2)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
auto removed = p.add_instruction(minus_op{}, two, one); auto removed = mm->add_instruction(minus_op{}, two, one);
p.add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
p.remove_instruction(removed); mm->remove_instruction(removed);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -375,10 +378,10 @@ TEST_CASE(remove_test2) ...@@ -375,10 +378,10 @@ TEST_CASE(remove_test2)
TEST_CASE(target_test) TEST_CASE(target_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
p.add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
p.compile(id_target{}); p.compile(id_target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
...@@ -388,10 +391,10 @@ TEST_CASE(target_test) ...@@ -388,10 +391,10 @@ TEST_CASE(target_test)
TEST_CASE(invert_target_test) TEST_CASE(invert_target_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
p.add_instruction(sum_op{}, two, one); mm->add_instruction(sum_op{}, two, one);
p.compile(invert_target{}); p.compile(invert_target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{1}); EXPECT(result == migraphx::literal{1});
...@@ -401,10 +404,10 @@ TEST_CASE(invert_target_test) ...@@ -401,10 +404,10 @@ TEST_CASE(invert_target_test)
TEST_CASE(double_invert_target_test) TEST_CASE(double_invert_target_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
p.add_instruction(sum_op{}, two, one); mm->add_instruction(sum_op{}, two, one);
p.compile(double_invert_target{}); p.compile(double_invert_target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
...@@ -414,10 +417,10 @@ TEST_CASE(double_invert_target_test) ...@@ -414,10 +417,10 @@ TEST_CASE(double_invert_target_test)
TEST_CASE(reverse_target_test) TEST_CASE(reverse_target_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
p.add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
EXPECT(test::throws<migraphx::exception>([&] { p.compile(reverse_target{}); })); EXPECT(test::throws<migraphx::exception>([&] { p.compile(reverse_target{}); }));
} }
...@@ -426,11 +429,12 @@ TEST_CASE(reverse_target_test) ...@@ -426,11 +429,12 @@ TEST_CASE(reverse_target_test)
TEST_CASE(eval_context1) TEST_CASE(eval_context1)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
id_target t{}; id_target t{};
EXPECT(is_shared(t.ctx, t.get_context())); EXPECT(is_shared(t.ctx, t.get_context()));
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
p.add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
p.compile(t); p.compile(t);
EXPECT(is_shared(t.ctx, p.get_context())); EXPECT(is_shared(t.ctx, p.get_context()));
p.eval({}).back(); p.eval({}).back();
...@@ -440,11 +444,12 @@ TEST_CASE(eval_context1) ...@@ -440,11 +444,12 @@ TEST_CASE(eval_context1)
TEST_CASE(eval_context2) TEST_CASE(eval_context2)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
id_target t{}; id_target t{};
EXPECT(is_shared(t.ctx, t.get_context())); EXPECT(is_shared(t.ctx, t.get_context()));
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
p.add_instruction(id_ctx_op{}, one, two); mm->add_instruction(id_ctx_op{}, one, two);
p.compile(t); p.compile(t);
EXPECT(is_shared(t.ctx, p.get_context())); EXPECT(is_shared(t.ctx, p.get_context()));
p.eval({}).back(); p.eval({}).back();
...@@ -455,11 +460,12 @@ TEST_CASE(eval_context2) ...@@ -455,11 +460,12 @@ TEST_CASE(eval_context2)
TEST_CASE(eval_context3) TEST_CASE(eval_context3)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
id_target t{}; id_target t{};
EXPECT(is_shared(t.ctx, t.get_context())); EXPECT(is_shared(t.ctx, t.get_context()));
auto one = p.add_literal(1); auto one = mm->add_literal(1);
auto two = p.add_literal(2); auto two = mm->add_literal(2);
p.add_instruction(id_ctx_final_op{}, one, two); mm->add_instruction(id_ctx_final_op{}, one, two);
p.compile(t); p.compile(t);
// Finalizer will modify the context // Finalizer will modify the context
EXPECT(not is_shared(t.ctx, p.get_context())); EXPECT(not is_shared(t.ctx, p.get_context()));
...@@ -495,11 +501,13 @@ std::string capture_output(F f) ...@@ -495,11 +501,13 @@ std::string capture_output(F f)
TEST_CASE(debug_print_test) TEST_CASE(debug_print_test)
{ {
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
std::vector<migraphx::instruction_ref> onev = {one}; std::vector<migraphx::instruction_ref> onev = {one};
migraphx::program p2; migraphx::program p2;
auto one2 = p2.add_literal(1); auto* mm2 = p2.get_main_module();
auto one2 = mm2->add_literal(1);
auto program_out = migraphx::trim(capture_output([&] { p.debug_print(); })); auto program_out = migraphx::trim(capture_output([&] { p.debug_print(); }));
auto ins_out = migraphx::trim(capture_output([&] { p.debug_print(one); })); auto ins_out = migraphx::trim(capture_output([&] { p.debug_print(one); }));
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
void run_lowering(migraphx::program& p) void run_lowering(migraphx::program& p)
{ {
auto ctx = migraphx::gpu::context{}; auto ctx = migraphx::gpu::context{};
migraphx::run_passes(p, migraphx::run_passes(*p.get_main_module(),
{migraphx::auto_contiguous{}, {migraphx::auto_contiguous{},
migraphx::gpu::lowering{&ctx, false}, migraphx::gpu::lowering{&ctx, false},
migraphx::dead_code_elimination{}, migraphx::dead_code_elimination{},
...@@ -30,12 +30,13 @@ TEST_CASE(tanh_shape) ...@@ -30,12 +30,13 @@ TEST_CASE(tanh_shape)
{ {
auto create_program = [] { auto create_program = [] {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto tx = p.add_instruction(migraphx::op::transpose{{1, 0}}, x); auto tx = mm->add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto txh = p.add_instruction(migraphx::op::tanh{}, tx); auto txh = mm->add_instruction(migraphx::op::tanh{}, tx);
auto sum = p.add_instruction(migraphx::op::add{}, txh, txh); auto sum = mm->add_instruction(migraphx::op::add{}, txh, txh);
p.add_instruction(migraphx::op::contiguous{}, sum); mm->add_instruction(migraphx::op::contiguous{}, sum);
return p; return p;
}; };
...@@ -59,7 +60,7 @@ TEST_CASE(tanh_shape) ...@@ -59,7 +60,7 @@ TEST_CASE(tanh_shape)
} }
EXPECT(p1 != p2); EXPECT(p1 != p2);
migraphx::run_passes(p2, migraphx::run_passes(*p2.get_main_module(),
{migraphx::gpu::adjust_allocation{}, migraphx::dead_code_elimination{}}); {migraphx::gpu::adjust_allocation{}, migraphx::dead_code_elimination{}});
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
migraphx::program create_gelu() migraphx::program create_gelu()
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data0 = {0.044715}; std::vector<float> data0 = {0.044715};
std::vector<float> data1 = {0.797885}; std::vector<float> data1 = {0.797885};
std::vector<float> data2 = {3}; std::vector<float> data2 = {3};
...@@ -20,25 +21,25 @@ migraphx::program create_gelu() ...@@ -20,25 +21,25 @@ migraphx::program create_gelu()
std::vector<size_t> x_dims{1, 1, 5}; std::vector<size_t> x_dims{1, 1, 5};
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, x_dims}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, x_dims});
auto const_val = p.add_literal(migraphx::literal{s0, data0}); auto const_val = mm->add_literal(migraphx::literal{s0, data0});
auto sqrt_2_pi = p.add_literal(migraphx::literal{s0, data1}); auto sqrt_2_pi = mm->add_literal(migraphx::literal{s0, data1});
auto three_val = p.add_literal(migraphx::literal{s0, data2}); auto three_val = mm->add_literal(migraphx::literal{s0, data2});
auto half_val = p.add_literal(migraphx::literal{s0, data3}); auto half_val = mm->add_literal(migraphx::literal{s0, data3});
auto mbcast_3 = p.add_instruction(migraphx::op::multibroadcast{x_dims}, three_val); auto mbcast_3 = mm->add_instruction(migraphx::op::multibroadcast{x_dims}, three_val);
auto pow_op = p.add_instruction(migraphx::op::pow{}, x, mbcast_3); auto pow_op = mm->add_instruction(migraphx::op::pow{}, x, mbcast_3);
auto mbcast_const = p.add_instruction(migraphx::op::multibroadcast{x_dims}, const_val); auto mbcast_const = mm->add_instruction(migraphx::op::multibroadcast{x_dims}, const_val);
auto mul_const = p.add_instruction(migraphx::op::mul{}, mbcast_const, pow_op); auto mul_const = mm->add_instruction(migraphx::op::mul{}, mbcast_const, pow_op);
auto add_x = p.add_instruction(migraphx::op::add{}, x, mul_const); auto add_x = mm->add_instruction(migraphx::op::add{}, x, mul_const);
auto mbcast_sqrt_2_pi = p.add_instruction(migraphx::op::multibroadcast{x_dims}, sqrt_2_pi); auto mbcast_sqrt_2_pi = mm->add_instruction(migraphx::op::multibroadcast{x_dims}, sqrt_2_pi);
auto mul_add_x = p.add_instruction(migraphx::op::mul{}, mbcast_sqrt_2_pi, add_x); auto mul_add_x = mm->add_instruction(migraphx::op::mul{}, mbcast_sqrt_2_pi, add_x);
auto tanh_op = p.add_instruction(migraphx::op::tanh{}, mul_add_x); auto tanh_op = mm->add_instruction(migraphx::op::tanh{}, mul_add_x);
auto mbcast_half = p.add_instruction(migraphx::op::multibroadcast{x_dims}, half_val); auto mbcast_half = mm->add_instruction(migraphx::op::multibroadcast{x_dims}, half_val);
auto mul_half = p.add_instruction(migraphx::op::mul{}, mbcast_half, tanh_op); auto mul_half = mm->add_instruction(migraphx::op::mul{}, mbcast_half, tanh_op);
auto add_mul_half = p.add_instruction(migraphx::op::add{}, mul_half, mbcast_half); auto add_mul_half = mm->add_instruction(migraphx::op::add{}, mul_half, mbcast_half);
auto mul_x = p.add_instruction(migraphx::op::mul{}, x, add_mul_half); auto mul_x = mm->add_instruction(migraphx::op::mul{}, x, add_mul_half);
p.add_return({mul_x}); mm->add_return({mul_x});
return p; return p;
} }
......
...@@ -9,8 +9,9 @@ ...@@ -9,8 +9,9 @@
void gpu_literal_test() void gpu_literal_test()
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto lit = generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto lit = generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
p.add_literal(lit); mm->add_literal(lit);
p.compile(migraphx::gpu::target{}); p.compile(migraphx::gpu::target{});
auto scratch = p.get_parameter("scratch"); auto scratch = p.get_parameter("scratch");
if(scratch == p.end()) if(scratch == p.end())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment