Unverified Commit 2466dd6f authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Refactor program to module (#684)



* code backup

* clang format

* change corresponding tool files

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent de10423f
......@@ -8,7 +8,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
void sync_device::apply(program& p) const
void sync_device::apply(module& p) const
{
auto last = std::prev(p.end());
if(last->name() == "@return")
......
......@@ -2,6 +2,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/env.hpp>
namespace migraphx {
......@@ -10,7 +11,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_COPY_LITERALS)
void write_literals::apply(program& p) const
void write_literals::apply(module& p) const
{
assert(ctx != nullptr);
std::size_t n = 0;
......
......@@ -11,7 +11,7 @@ namespace ref {
struct lowering
{
std::string name() const { return "ref::lowering"; }
void apply(program& p) const;
void apply(module& m) const;
};
} // namespace ref
......
......@@ -882,7 +882,7 @@ MIGRAPHX_REGISTER_OP(ref_rnn_var_sl_last_output)
struct ref_apply
{
program* prog;
module* modl;
std::unordered_map<std::string, std::function<void(instruction_ref)>> apply_map{};
template <class T>
......@@ -922,7 +922,7 @@ struct ref_apply
void apply()
{
init();
for(auto it : iterator_for(*prog))
for(auto it : iterator_for(*modl))
{
if(it->name() == "pooling")
{
......@@ -941,33 +941,33 @@ struct ref_apply
void apply_ref_op(instruction_ref ins) const
{
prog->replace_instruction(ins, ref_op{ins->get_operator()}, ins->inputs());
modl->replace_instruction(ins, ref_op{ins->get_operator()}, ins->inputs());
}
template <class T>
void apply_simple_op(instruction_ref ins)
{
prog->replace_instruction(ins, T{}, ins->inputs());
modl->replace_instruction(ins, T{}, ins->inputs());
}
template <class T, class Op>
void apply_extend_op(instruction_ref ins)
{
auto&& op = any_cast<Op>(ins->get_operator());
prog->replace_instruction(ins, T{op}, ins->inputs());
modl->replace_instruction(ins, T{op}, ins->inputs());
}
void apply_pooling(instruction_ref ins) const
{
auto&& op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "max")
prog->replace_instruction(ins, ref_pooling<max_pool>{op}, ins->inputs());
modl->replace_instruction(ins, ref_pooling<max_pool>{op}, ins->inputs());
else if(op.mode == "average")
prog->replace_instruction(ins, ref_pooling<avg_pool>{op}, ins->inputs());
modl->replace_instruction(ins, ref_pooling<avg_pool>{op}, ins->inputs());
}
};
void lowering::apply(program& p) const { ref_apply{&p}.apply(); }
void lowering::apply(module& m) const { ref_apply{&m}.apply(); }
} // namespace ref
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -33,6 +33,7 @@ struct tf_parser
std::vector<tensorflow::NodeDef> input_nodes;
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
module* mm = prog.get_main_module();
bool is_nhwc = true;
unsigned int batch_size = 1;
......@@ -43,33 +44,33 @@ struct tf_parser
return is_nhwc and ins->get_shape().lens().size() == 4;
}
instruction_ref to_nhwc(instruction_ref ins)
instruction_ref to_nhwc(instruction_ref ins) const
{
if(should_transpose(ins))
return prog.add_instruction(op::transpose{{0, 2, 3, 1}}, ins);
return mm->add_instruction(op::transpose{{0, 2, 3, 1}}, ins);
return ins;
}
instruction_ref to_nchw(instruction_ref ins)
instruction_ref to_nchw(instruction_ref ins) const
{
if(should_transpose(ins))
return prog.add_instruction(op::transpose{{0, 3, 1, 2}}, ins);
return mm->add_instruction(op::transpose{{0, 3, 1, 2}}, ins);
return ins;
}
instruction_ref to_kcxy(instruction_ref ins)
instruction_ref to_kcxy(instruction_ref ins) const
{
if(should_transpose(ins))
return prog.add_instruction(op::transpose{{3, 2, 0, 1}}, ins);
return mm->add_instruction(op::transpose{{3, 2, 0, 1}}, ins);
return ins;
}
instruction_ref make_contiguous(instruction_ref ins)
instruction_ref make_contiguous(instruction_ref ins) const
{
if(ins->get_shape().standard())
return ins;
else
return prog.add_instruction(op::contiguous{}, ins);
return mm->add_instruction(op::contiguous{}, ins);
}
std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args)
......@@ -266,7 +267,7 @@ struct tf_parser
// {
// if(is_nhwc)
// {
// l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
// l0 = mm->add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
// }
// }
return add_broadcastable_binary_op(args[0], args[1], x);
......@@ -308,13 +309,13 @@ struct tf_parser
output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
return to_nhwc(prog.add_instruction(x, to_nchw(l0), to_nchw(l1)));
auto l0 = mm->add_instruction(op::multibroadcast{output_lens}, arg0);
auto l1 = mm->add_instruction(op::multibroadcast{output_lens}, arg1);
return to_nhwc(mm->add_instruction(x, to_nchw(l0), to_nchw(l1)));
}
else
{
return to_nhwc(prog.add_instruction(x, {to_nchw(arg0), to_nchw(arg1)}));
return to_nhwc(mm->add_instruction(x, {to_nchw(arg0), to_nchw(arg1)}));
}
}
......@@ -323,7 +324,7 @@ struct tf_parser
{
add_op(name,
[this, x](const attribute_map&, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args);
return mm->add_instruction(x, args);
},
transpose);
}
......@@ -334,12 +335,13 @@ struct tf_parser
{
int64_t axis = 0;
axis = args[1]->eval().at<int64_t>();
auto ins = prog.add_instruction(Op{axis}, args.front());
return prog.add_instruction(op::squeeze{{axis}}, ins);
auto ins = mm->add_instruction(Op{axis}, args.front());
return mm->add_instruction(op::squeeze{{axis}}, ins);
}
instruction_ref
parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
instruction_ref parse_batchnorm(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args) const
{
float epsilon = 1e-5f;
float momentum = 0.9f;
......@@ -349,46 +351,49 @@ struct tf_parser
epsilon = attributes.at("epsilon").f();
}
op::batch_norm_inference op{epsilon, momentum, bn_mode};
return prog.add_instruction(op, std::move(args));
return mm->add_instruction(op, std::move(args));
}
instruction_ref
parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{
uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
auto l0 = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()}, args[1]);
return prog.add_instruction(op::add{}, args[0], l0);
auto l0 = mm->add_instruction(op::broadcast{axis, args[0]->get_shape().lens()}, args[1]);
return mm->add_instruction(op::add{}, args[0], l0);
}
instruction_ref
parse_cast(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
instruction_ref parse_cast(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args) const
{
shape::type_t type = parse_type(attributes.at("DstT").type());
return prog.add_instruction(op::convert{type}, std::move(args));
return mm->add_instruction(op::convert{type}, std::move(args));
}
instruction_ref
parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
instruction_ref parse_concat(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args) const
{
// get index for axis within args
size_t axis_idx = attributes.at("N").i();
int64_t axis = args[axis_idx]->eval().at<int64_t>();
op::concat op{axis};
// return only first N arguments (assuming last index is the axis value)
return prog.add_instruction(
return mm->add_instruction(
op, std::vector<instruction_ref>(args.begin(), args.begin() + args.size() - 1));
}
instruction_ref parse_constant(const std::string&,
attribute_map attributes,
const std::vector<instruction_ref>&)
const std::vector<instruction_ref>&) const
{
literal v = parse_tensor(attributes.at("value").tensor());
return prog.add_literal(v);
return mm->add_literal(v);
}
instruction_ref
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
instruction_ref parse_conv(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args) const
{
op::convolution op;
if(contains(attributes, "strides"))
......@@ -436,7 +441,7 @@ struct tf_parser
if(pads[0] != pads[2] || pads[1] != pads[3])
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = prog.add_instruction(migraphx::op::pad{padding}, l0);
l0 = mm->add_instruction(migraphx::op::pad{padding}, l0);
}
else
{
......@@ -464,12 +469,12 @@ struct tf_parser
op.padding[1] = padding[1];
}
}
return prog.add_instruction(op, {l0, to_kcxy(args[1])});
return mm->add_instruction(op, {l0, to_kcxy(args[1])});
}
instruction_ref parse_depthwiseconv(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
std::vector<instruction_ref> args) const
{
op::convolution op;
size_t num_channels = args[0]->get_shape().lens()[1];
......@@ -522,7 +527,7 @@ struct tf_parser
if(pads[0] != pads[2] || pads[1] != pads[3])
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = prog.add_instruction(migraphx::op::pad{padding}, l0);
l0 = mm->add_instruction(migraphx::op::pad{padding}, l0);
}
else
{
......@@ -548,13 +553,14 @@ struct tf_parser
new_weights_shape[1] = 1;
// Make sure weights are contiguous before doing reshape
auto new_weights =
prog.add_instruction(op::reshape{new_weights_shape}, make_contiguous(weights));
mm->add_instruction(op::reshape{new_weights_shape}, make_contiguous(weights));
return prog.add_instruction(op, {l0, new_weights});
return mm->add_instruction(op, {l0, new_weights});
}
instruction_ref
parse_expanddims(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
instruction_ref parse_expanddims(const std::string&,
const attribute_map&,
std::vector<instruction_ref> args) const
{
std::vector<size_t> input_dims = args[0]->get_shape().lens();
std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end());
......@@ -569,19 +575,20 @@ struct tf_parser
{
new_dims.insert(new_dims.begin() + dim, 1);
}
return prog.add_instruction(op::reshape{new_dims}, args[0]);
return mm->add_instruction(op::reshape{new_dims}, args[0]);
}
instruction_ref
parse_gather(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
parse_gather(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{
int axis = args[2]->eval().at<int32_t>();
op::gather op{axis};
return prog.add_instruction(op, {args[0], args[1]});
return mm->add_instruction(op, {args[0], args[1]});
}
instruction_ref
parse_matmul(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
instruction_ref parse_matmul(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args) const
{
bool transa = false;
bool transb = false;
......@@ -609,31 +616,33 @@ struct tf_parser
// swap the last two elements
std::iter_swap(perm.end() - 1, perm.end() - 2);
auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
auto l1 = (transa) ? mm->add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? mm->add_instruction(op::transpose{perm}, args[1]) : args[1];
return prog.add_instruction(op::dot{}, l1, l2);
return mm->add_instruction(op::dot{}, l1, l2);
}
instruction_ref
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
instruction_ref parse_mean(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args) const
{
bool keep_dims = attributes.at("keep_dims").b();
auto axes = args[1]->eval().get<int32_t>().to_vector<int64_t>();
if(keep_dims)
{
return prog.add_instruction(op::reduce_mean{axes}, args[0]);
return mm->add_instruction(op::reduce_mean{axes}, args[0]);
}
else
{
auto ins = prog.add_instruction(op::reduce_mean{axes}, args[0]);
return prog.add_instruction(op::squeeze{axes}, ins);
auto ins = mm->add_instruction(op::reduce_mean{axes}, args[0]);
return mm->add_instruction(op::squeeze{axes}, ins);
}
}
instruction_ref
parse_onehot(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
instruction_ref parse_onehot(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args) const
{
size_t depth = static_cast<size_t>(args[1]->eval().at<int32_t>());
......@@ -652,15 +661,15 @@ struct tf_parser
if(axis == -1)
{
shape s{shape::float_type, {depth, depth}};
auto l0 = prog.add_literal({s, depth_input});
return prog.add_instruction(op::gather{0}, {l0, args[0]});
auto l0 = mm->add_literal({s, depth_input});
return mm->add_instruction(op::gather{0}, {l0, args[0]});
}
MIGRAPHX_THROW("MIGraphX does not support axis != -1");
}
instruction_ref parse_pack(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
std::vector<instruction_ref> args) const
{
// reinterpret as unsqueeze with concat
std::vector<instruction_ref> unsqueezed_args;
......@@ -678,12 +687,12 @@ struct tf_parser
args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); });
return to_nhwc(prog.add_instruction(op::concat{axis}, unsqueezed_args));
[&](instruction_ref arg) { return mm->add_instruction(op::unsqueeze{{axis}}, arg); });
return to_nhwc(mm->add_instruction(op::concat{axis}, unsqueezed_args));
}
instruction_ref
parse_pad(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
parse_pad(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{
size_t ndims = args.front()->get_shape().lens().size();
......@@ -706,12 +715,12 @@ struct tf_parser
pads[i + ndims] = pad_per_dim[i].second;
}
op.pads = pads;
return prog.add_instruction(op, args.front());
return mm->add_instruction(op, args.front());
}
instruction_ref parse_pooling(const std::string& name,
attribute_map attributes,
std::vector<instruction_ref> args)
std::vector<instruction_ref> args) const
{
op::pooling op{starts_with(name, "Max") ? "max" : "average"};
......@@ -754,7 +763,7 @@ struct tf_parser
if(pads[0] != pads[2] || pads[1] != pads[3])
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = prog.add_instruction(
l0 = mm->add_instruction(
migraphx::op::pad{padding, std::numeric_limits<float>::lowest()}, l0);
}
else
......@@ -764,47 +773,47 @@ struct tf_parser
}
}
}
return prog.add_instruction(op, l0);
return mm->add_instruction(op, l0);
}
instruction_ref
parse_relu6(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
parse_relu6(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{
auto input_lens = args[0]->get_shape().lens();
auto min_val = prog.add_literal(0.0f);
auto max_val = prog.add_literal(6.0f);
auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f);
min_val = prog.add_instruction(op::multibroadcast{input_lens}, min_val);
max_val = prog.add_instruction(op::multibroadcast{input_lens}, max_val);
return prog.add_instruction(op::clip{}, args.front(), min_val, max_val);
min_val = mm->add_instruction(op::multibroadcast{input_lens}, min_val);
max_val = mm->add_instruction(op::multibroadcast{input_lens}, max_val);
return mm->add_instruction(op::clip{}, args.front(), min_val, max_val);
}
instruction_ref
parse_reshape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
parse_reshape(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{
op::reshape op;
if(args.size() != 2)
MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");
auto s = args[1]->eval();
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
return prog.add_instruction(op, make_contiguous(args[0]));
return mm->add_instruction(op, make_contiguous(args[0]));
}
// Use a literal instruction to replace the shape since output of
// shape operator are literals in migraphx
instruction_ref
parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{
std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
std::vector<int32_t> vec_shape(arg_shape.size());
migraphx::shape s(migraphx::shape::int32_type, {arg_shape.size()});
std::transform(
arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) { return i; });
return prog.add_literal(migraphx::literal{s, vec_shape});
return mm->add_literal(migraphx::literal{s, vec_shape});
}
instruction_ref
parse_slice(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
parse_slice(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{
op::slice op;
auto starts = args[1]->eval().get<int32_t>().to_vector();
......@@ -823,7 +832,7 @@ struct tf_parser
else
op.ends[i] = starts[i] + size[i];
}
return prog.add_instruction(op, make_contiguous(args[0]));
return mm->add_instruction(op, make_contiguous(args[0]));
}
// template to facilitate the logsoftmax later
......@@ -843,12 +852,12 @@ struct tf_parser
axis += num_dims;
}
return prog.add_instruction(Op{axis}, make_contiguous(args[0]));
return mm->add_instruction(Op{axis}, make_contiguous(args[0]));
}
std::vector<instruction_ref> parse_split(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
std::vector<instruction_ref> args) const
{
bool vector_as_input = args.size() == 3;
int num_outputs = 1;
......@@ -874,7 +883,7 @@ struct tf_parser
assert(num_outputs > 0);
if(num_outputs == 1)
return std::vector<instruction_ref>{prog.add_instruction(op::identity{}, input_arg)};
return std::vector<instruction_ref>{mm->add_instruction(op::identity{}, input_arg)};
auto lens = input_arg->get_shape().lens();
auto num_dims = lens.size();
......@@ -919,14 +928,14 @@ struct tf_parser
op.starts[axis] = slice_pos[i];
op.ends[axis] = slice_pos[i + 1];
result.push_back(prog.add_instruction(op, input_arg));
result.push_back(mm->add_instruction(op, input_arg));
}
return result;
}
instruction_ref parse_squeeze(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
std::vector<instruction_ref> args) const
{
op::squeeze op;
auto input_dims = args[0]->get_shape().lens();
......@@ -943,7 +952,7 @@ struct tf_parser
}
}
}
return prog.add_instruction(op, make_contiguous(args[0]));
return mm->add_instruction(op, make_contiguous(args[0]));
}
instruction_ref parse_stridedslice(const std::string&,
......@@ -991,7 +1000,7 @@ struct tf_parser
}
}
auto l1 = prog.add_instruction(op, l0);
auto l1 = mm->add_instruction(op, l0);
if(shrink_axis_mask == 0)
return l1;
......@@ -1002,17 +1011,18 @@ struct tf_parser
squeeze_axes.push_back(i);
}
return prog.add_instruction(op::squeeze{squeeze_axes}, l1);
return mm->add_instruction(op::squeeze{squeeze_axes}, l1);
}
instruction_ref
parse_transpose(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
instruction_ref parse_transpose(const std::string&,
const attribute_map&,
std::vector<instruction_ref> args) const
{
auto perm = args[1]->eval().get<int32_t>().to_vector();
op::transpose op;
op.dims = std::vector<int64_t>(perm.begin(), perm.end());
return prog.add_instruction(op, args.front());
return mm->add_instruction(op, args.front());
}
void parse_graph(const tensorflow::GraphDef& graph)
......@@ -1032,7 +1042,7 @@ struct tf_parser
return static_cast<int>(dim) <= 0 ? batch_size : dim;
});
shape s = shape{shape_type, dims};
instructions[name] = to_nhwc(prog.add_parameter(name, s));
instructions[name] = to_nhwc(mm->add_parameter(name, s));
}
for(auto&& p : nodes)
{
......@@ -1083,7 +1093,7 @@ struct tf_parser
std::vector<instruction_ref> result;
if(ops.count(node.op()) == 0)
{
result.push_back(prog.add_instruction(op::unknown{node.op()}, args));
result.push_back(mm->add_instruction(op::unknown{node.op()}, args));
}
else
{
......@@ -1405,7 +1415,7 @@ program parse_tf(const std::string& name, tf_options options)
#else
parser.parse_from(input);
#endif
parser.to_nchw(std::prev(parser.prog.end()));
parser.to_nchw(std::prev(parser.mm->end()));
return std::move(parser.prog);
}
......
......@@ -75,26 +75,27 @@ struct test_stream_model
struct program_model
{
migraphx::program p;
migraphx::module* mm = p.get_main_module();
std::unordered_map<migraphx::instruction_ref, std::size_t> ins2stream{};
std::size_t max_stream = 0;
template <class... Ts>
migraphx::instruction_ref add_literal(Ts... xs)
{
return p.add_literal(xs...);
return mm->add_literal(xs...);
}
template <class... Ts>
migraphx::instruction_ref add_instruction(Ts... xs)
{
return p.add_instruction(xs...);
return mm->add_instruction(xs...);
}
template <class... Ts>
migraphx::instruction_ref add_instruction_stream(std::size_t n, Ts... xs)
{
max_stream = std::max(max_stream, n);
auto ins = p.add_instruction(xs...);
auto ins = mm->add_instruction(xs...);
ins2stream[ins] = n;
return ins;
}
......@@ -102,14 +103,14 @@ struct program_model
template <class... Ts>
migraphx::instruction_ref add_return(Ts... xs)
{
return p.add_return({xs...});
return mm->add_return({xs...});
}
template <class... Ts>
migraphx::instruction_ref add_return_stream(std::size_t n, Ts... xs)
{
max_stream = std::max(max_stream, n);
auto ins = p.add_return({xs...});
auto ins = mm->add_return({xs...});
ins2stream[ins] = n;
return ins;
}
......@@ -118,7 +119,7 @@ struct program_model
std::vector<migraphx::stream_race> analyze() const
{
return migraphx::analyze_streams(p, get_stream_model());
return migraphx::analyze_streams(*p.get_main_module(), get_stream_model());
}
void debug_print() const { p.debug_print(); }
......
......@@ -6,13 +6,18 @@
#include <basic_ops.hpp>
#include <test.hpp>
void run_pass(migraphx::program& p) { migraphx::run_passes(p, {migraphx::auto_contiguous{}}); }
void run_pass(migraphx::program& p)
{
migraphx::run_passes(*p.get_main_module(), {migraphx::auto_contiguous{}});
}
// TODO: Add this test case
void literal_broadcast()
{
migraphx::program p;
p.add_literal(get_2_broadcasted());
auto* mm = p.get_main_module();
mm->add_literal(get_2_broadcasted());
EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().broadcasted());
run_pass(p);
......@@ -23,7 +28,9 @@ void literal_broadcast()
TEST_CASE(literal_transpose)
{
migraphx::program p;
p.add_literal(get_2x2_transposed());
auto* mm = p.get_main_module();
mm->add_literal(get_2x2_transposed());
EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed());
run_pass(p);
......@@ -34,11 +41,13 @@ TEST_CASE(literal_transpose)
TEST_CASE(after_literal_transpose)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
p.add_instruction(pass_op{}, t);
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
mm->add_instruction(pass_op{}, t);
EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed());
run_pass(p);
......@@ -49,12 +58,14 @@ TEST_CASE(after_literal_transpose)
TEST_CASE(after_literal_broadcast)
{
migraphx::program p;
auto l1 = p.add_literal(get_2x2());
auto l2 = p.add_literal(get_2());
auto* mm = p.get_main_module();
auto l1 = mm->add_literal(get_2x2());
auto l2 = mm->add_literal(get_2());
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().broadcasted());
auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2);
p.add_instruction(pass_op{}, b);
auto b = mm->add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2);
mm->add_instruction(pass_op{}, b);
EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().broadcasted());
run_pass(p);
......@@ -65,11 +76,13 @@ TEST_CASE(after_literal_broadcast)
TEST_CASE(after_param_transpose)
{
migraphx::program p;
auto l = p.add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
auto* mm = p.get_main_module();
auto l = mm->add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
p.add_instruction(pass_op{}, t);
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
mm->add_instruction(pass_op{}, t);
EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed());
run_pass(p);
......@@ -80,12 +93,14 @@ TEST_CASE(after_param_transpose)
TEST_CASE(after_param_broadcast)
{
migraphx::program p;
auto l1 = p.add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {2}});
auto* mm = p.get_main_module();
auto l1 = mm->add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {2}});
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().broadcasted());
auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2);
p.add_instruction(pass_op{}, b);
auto b = mm->add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2);
mm->add_instruction(pass_op{}, b);
EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().broadcasted());
run_pass(p);
......
......@@ -52,42 +52,52 @@ struct test_context
TEST_CASE(literal_test)
{
migraphx::program p;
auto lit = p.add_literal(1);
auto* mm = p.get_main_module();
auto lit = mm->add_literal(1);
CHECK(lit->eval() == migraphx::literal{1});
}
TEST_CASE(param_test)
{
migraphx::program p;
auto lit = p.add_parameter("param", migraphx::shape{migraphx::shape::float_type, {1}});
auto* mm = p.get_main_module();
auto lit = mm->add_parameter("param", migraphx::shape{migraphx::shape::float_type, {1}});
CHECK(lit->eval().empty());
}
TEST_CASE(op_test1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_cf_op{}, one, two);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_cf_op{}, one, two);
CHECK(sum->eval() == migraphx::literal{3});
}
TEST_CASE(op_test2)
{
migraphx::program p;
auto x = p.add_parameter("param", migraphx::shape{migraphx::shape::float_type, {1}});
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_cf_op{}, x, two);
auto* mm = p.get_main_module();
auto x = mm->add_parameter("param", migraphx::shape{migraphx::shape::float_type, {1}});
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_cf_op{}, x, two);
CHECK(sum->eval().empty());
}
TEST_CASE(op_test3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_cf_op{}, sum1, two);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_cf_op{}, sum1, two);
CHECK(sum2->eval().empty());
}
......
......@@ -8,16 +8,16 @@
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::dead_code_elimination{}});
migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}});
}
TEST_CASE(simple_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, one, two);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count);
......@@ -29,11 +29,11 @@ TEST_CASE(simple_test)
TEST_CASE(simple_test_nop)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(nop{});
p.add_instruction(sum_op{}, one, two);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(nop{});
mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count);
......@@ -45,12 +45,12 @@ TEST_CASE(simple_test_nop)
TEST_CASE(simple_test_nop2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(nop{});
p.add_instruction(sum_op{}, one, two);
p.add_instruction(nop{});
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(nop{});
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(nop{});
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == 2);
auto result = p.eval({}).back();
......@@ -61,11 +61,11 @@ TEST_CASE(simple_test_nop2)
TEST_CASE(duplicate_test1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, one, two);
p.add_instruction(sum_op{}, one, two);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == (count - 1));
......@@ -77,12 +77,12 @@ TEST_CASE(duplicate_test1)
TEST_CASE(duplicate_test2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, one, two);
p.add_instruction(minus_op{}, one, two);
p.add_instruction(sum_op{}, one, two);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(minus_op{}, one, two);
mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == (count - 2));
......@@ -94,14 +94,14 @@ TEST_CASE(duplicate_test2)
TEST_CASE(depth_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto x1 = p.add_instruction(sum_op{}, one, two);
auto x2 = p.add_instruction(sum_op{}, one, two);
p.add_instruction(minus_op{}, x1, x2);
p.add_instruction(minus_op{}, x1, x2);
p.add_instruction(sum_op{}, one, two);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto x1 = mm->add_instruction(sum_op{}, one, two);
auto x2 = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(minus_op{}, x1, x2);
mm->add_instruction(minus_op{}, x1, x2);
mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == (count - 4));
......@@ -113,15 +113,15 @@ TEST_CASE(depth_test)
TEST_CASE(undefined_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto undef = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(sum_op{}, one, two);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto undef = mm->add_instruction(migraphx::op::undefined{});
mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count - 1);
EXPECT(not p.has_instruction(undef));
EXPECT(not mm->has_instruction(undef));
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4});
......@@ -130,11 +130,11 @@ TEST_CASE(undefined_test)
TEST_CASE(duplicate_args1)
{
migraphx::program p;
auto l0 = p.add_literal(0);
auto l3 = p.add_literal(3);
p.add_instruction(migraphx::op::add{}, l3, l3);
p.add_instruction(migraphx::op::identity{}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_literal(0);
auto l3 = mm->add_literal(3);
mm->add_instruction(migraphx::op::add{}, l3, l3);
mm->add_instruction(migraphx::op::identity{}, l0);
auto count = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) != count);
......@@ -146,12 +146,12 @@ TEST_CASE(duplicate_args1)
TEST_CASE(duplicate_args2)
{
migraphx::program p;
auto l0 = p.add_literal(0);
auto l3 = p.add_literal(3);
auto sum1 = p.add_instruction(migraphx::op::add{}, l0, l3);
p.add_instruction(migraphx::op::add{}, sum1, l3);
p.add_instruction(migraphx::op::identity{}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_literal(0);
auto l3 = mm->add_literal(3);
auto sum1 = mm->add_instruction(migraphx::op::add{}, l0, l3);
mm->add_instruction(migraphx::op::add{}, sum1, l3);
mm->add_instruction(migraphx::op::identity{}, l0);
auto count = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) != count);
......@@ -163,13 +163,13 @@ TEST_CASE(duplicate_args2)
TEST_CASE(duplicate_args3)
{
migraphx::program p;
auto l0 = p.add_literal(0);
auto l3 = p.add_literal(3);
auto sum1 = p.add_instruction(migraphx::op::add{}, l0, l3);
auto sum2 = p.add_instruction(migraphx::op::add{}, l0, sum1);
p.add_instruction(migraphx::op::add{}, sum2, l3);
p.add_instruction(migraphx::op::identity{}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_literal(0);
auto l3 = mm->add_literal(3);
auto sum1 = mm->add_instruction(migraphx::op::add{}, l0, l3);
auto sum2 = mm->add_instruction(migraphx::op::add{}, l0, sum1);
mm->add_instruction(migraphx::op::add{}, sum2, l3);
mm->add_instruction(migraphx::op::identity{}, l0);
auto count = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) != count);
......
......@@ -8,27 +8,32 @@
#include <migraphx/op/multibroadcast.hpp>
#include <test.hpp>
void run_pass(migraphx::program& p) { migraphx::run_passes(p, {migraphx::decompose{}}); }
void run_pass(migraphx::program& p)
{
migraphx::run_passes(*p.get_main_module(), {migraphx::decompose{}});
}
TEST_CASE(dot_add)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = p1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = p1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = p1.add_instruction(migraphx::op::dot{}, x, y, z);
p1.add_instruction(migraphx::op::identity{}, dot);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = mm1->add_instruction(migraphx::op::dot{}, x, y, z);
mm1->add_instruction(migraphx::op::identity{}, dot);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = p2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = p2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = p2.add_instruction(migraphx::op::dot{1, 0}, x, y);
auto add = p2.add_instruction(migraphx::op::add{}, dot, z);
p2.add_instruction(migraphx::op::identity{}, add);
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = mm2->add_instruction(migraphx::op::dot{1, 0}, x, y);
auto add = mm2->add_instruction(migraphx::op::add{}, dot, z);
mm2->add_instruction(migraphx::op::identity{}, add);
}
EXPECT(p1 == p2);
}
......@@ -37,25 +42,27 @@ TEST_CASE(dot_add_beta_float)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = p1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = p1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = p1.add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z);
p1.add_instruction(migraphx::op::identity{}, dot);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = mm1->add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z);
mm1->add_instruction(migraphx::op::identity{}, dot);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = p2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = p2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = p2.add_instruction(migraphx::op::dot{1, 0}, x, y);
auto beta =
p2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}});
auto beta_broadcast = p2.add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta);
auto mul = p2.add_instruction(migraphx::op::mul{}, z, beta_broadcast);
auto add = p2.add_instruction(migraphx::op::add{}, dot, mul);
p2.add_instruction(migraphx::op::identity{}, add);
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = mm2->add_instruction(migraphx::op::dot{1, 0}, x, y);
auto beta = mm2->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}});
auto beta_broadcast = mm2->add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta);
auto mul = mm2->add_instruction(migraphx::op::mul{}, z, beta_broadcast);
auto add = mm2->add_instruction(migraphx::op::add{}, dot, mul);
mm2->add_instruction(migraphx::op::identity{}, add);
}
EXPECT(p1 == p2);
}
......@@ -64,25 +71,27 @@ TEST_CASE(dot_add_beta_half)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = p1.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = p1.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot = p1.add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z);
p1.add_instruction(migraphx::op::identity{}, dot);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot = mm1->add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z);
mm1->add_instruction(migraphx::op::identity{}, dot);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = p2.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = p2.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot = p2.add_instruction(migraphx::op::dot{1, 0}, x, y);
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot = mm2->add_instruction(migraphx::op::dot{1, 0}, x, y);
auto beta =
p2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}});
auto beta_broadcast = p2.add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta);
auto mul = p2.add_instruction(migraphx::op::mul{}, z, beta_broadcast);
auto add = p2.add_instruction(migraphx::op::add{}, dot, mul);
p2.add_instruction(migraphx::op::identity{}, add);
mm2->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}});
auto beta_broadcast = mm2->add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta);
auto mul = mm2->add_instruction(migraphx::op::mul{}, z, beta_broadcast);
auto add = mm2->add_instruction(migraphx::op::add{}, dot, mul);
mm2->add_instruction(migraphx::op::identity{}, add);
}
EXPECT(p1 == p2);
}
......@@ -91,25 +100,27 @@ TEST_CASE(dot_add_beta_double)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = p1.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = p1.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto dot = p1.add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z);
p1.add_instruction(migraphx::op::identity{}, dot);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto dot = mm1->add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z);
mm1->add_instruction(migraphx::op::identity{}, dot);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = p2.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = p2.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto dot = p2.add_instruction(migraphx::op::dot{1, 0}, x, y);
auto beta =
p2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}});
auto beta_broadcast = p2.add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta);
auto mul = p2.add_instruction(migraphx::op::mul{}, z, beta_broadcast);
auto add = p2.add_instruction(migraphx::op::add{}, dot, mul);
p2.add_instruction(migraphx::op::identity{}, add);
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto dot = mm2->add_instruction(migraphx::op::dot{1, 0}, x, y);
auto beta = mm2->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}});
auto beta_broadcast = mm2->add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta);
auto mul = mm2->add_instruction(migraphx::op::mul{}, z, beta_broadcast);
auto add = mm2->add_instruction(migraphx::op::add{}, dot, mul);
mm2->add_instruction(migraphx::op::identity{}, add);
}
EXPECT(p1 == p2);
}
......@@ -118,11 +129,12 @@ TEST_CASE(dot_add_beta_int)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto y = p1.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto z = p1.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto dot = p1.add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z);
p1.add_instruction(migraphx::op::identity{}, dot);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto dot = mm1->add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z);
mm1->add_instruction(migraphx::op::identity{}, dot);
}
migraphx::program p2 = p1;
run_pass(p1);
......
......@@ -9,7 +9,8 @@
void run_pass(migraphx::program& p, std::size_t align = 32)
{
migraphx::run_passes(
p, {migraphx::eliminate_allocation{"allocate", align}, migraphx::dead_code_elimination{}});
*p.get_main_module(),
{migraphx::eliminate_allocation{"allocate", align}, migraphx::dead_code_elimination{}});
}
struct allocate
......@@ -39,14 +40,16 @@ struct allocate
TEST_CASE(basic)
{
migraphx::program p;
auto a1 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {8}}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {40}}});
auto p2 = p.add_instruction(pass_op{}, a2, p1);
auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {8}}});
auto p1 = mm->add_instruction(pass_op{}, a1);
auto a3 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
p.add_instruction(pass_op{}, a3, p2);
auto a2 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {40}}});
auto p2 = mm->add_instruction(pass_op{}, a2, p1);
auto a3 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
mm->add_instruction(pass_op{}, a3, p2);
run_pass(p);
EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}});
......@@ -56,14 +59,16 @@ TEST_CASE(basic)
TEST_CASE(aligned)
{
migraphx::program p;
auto a1 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}});
auto p2 = p.add_instruction(pass_op{}, a2, p1);
auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
auto p1 = mm->add_instruction(pass_op{}, a1);
auto a2 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}});
auto p2 = mm->add_instruction(pass_op{}, a2, p1);
auto a3 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
p.add_instruction(pass_op{}, a3, p2);
auto a3 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
mm->add_instruction(pass_op{}, a3, p2);
run_pass(p);
EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}});
......@@ -73,14 +78,16 @@ TEST_CASE(aligned)
TEST_CASE(unaligned)
{
migraphx::program p;
auto a1 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}});
auto p2 = p.add_instruction(pass_op{}, a2, p1);
auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
auto p1 = mm->add_instruction(pass_op{}, a1);
auto a3 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
p.add_instruction(pass_op{}, a3, p2);
auto a2 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}});
auto p2 = mm->add_instruction(pass_op{}, a2, p1);
auto a3 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
mm->add_instruction(pass_op{}, a3, p2);
run_pass(p, 1);
EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}});
......@@ -90,14 +97,16 @@ TEST_CASE(unaligned)
TEST_CASE(float_aligned)
{
migraphx::program p;
auto a1 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}});
auto p2 = p.add_instruction(pass_op{}, a2, p1);
auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
auto p1 = mm->add_instruction(pass_op{}, a1);
auto a2 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}});
auto p2 = mm->add_instruction(pass_op{}, a2, p1);
auto a3 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
p.add_instruction(pass_op{}, a3, p2);
auto a3 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
mm->add_instruction(pass_op{}, a3, p2);
run_pass(p, 4);
EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}});
......
......@@ -8,29 +8,32 @@
void run_pass(migraphx::program& p)
{
migraphx::run_passes(
p, {migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}});
*p.get_main_module(),
{migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(cse_test1)
{
migraphx::program p1;
{
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p1.add_instruction(migraphx::op::add{}, one, two);
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3);
auto* mm1 = p1.get_main_module();
auto one = mm1->add_literal(1);
auto two = mm1->add_literal(2);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, one, two);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, one, two);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
mm1->add_instruction(pass_op{}, sum3);
}
run_pass(p1);
migraphx::program p2;
{
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two);
auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum1);
p2.add_instruction(pass_op{}, sum3);
auto* mm2 = p2.get_main_module();
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, two);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum1, sum1);
mm2->add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
}
......@@ -39,23 +42,25 @@ TEST_CASE(cse_test2)
{
migraphx::program p1;
{
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p1.add_instruction(migraphx::op::add{}, two, one);
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3);
auto* mm1 = p1.get_main_module();
auto one = mm1->add_literal(1);
auto two = mm1->add_literal(2);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, one, two);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, two, one);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
mm1->add_instruction(pass_op{}, sum3);
}
run_pass(p1);
migraphx::program p2;
{
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p2.add_instruction(migraphx::op::add{}, two, one);
auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum2);
p2.add_instruction(pass_op{}, sum3);
auto* mm2 = p2.get_main_module();
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, two);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, two, one);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum1, sum2);
mm2->add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
}
......@@ -64,21 +69,23 @@ TEST_CASE(cse_test3)
{
migraphx::program p1;
{
auto one = p1.add_literal(1);
auto two = p1.add_literal(1);
auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p1.add_instruction(migraphx::op::add{}, two, one);
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3);
auto* mm1 = p1.get_main_module();
auto one = mm1->add_literal(1);
auto two = mm1->add_literal(1);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, one, two);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, two, one);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
mm1->add_instruction(pass_op{}, sum3);
}
run_pass(p1);
migraphx::program p2;
{
auto one = p2.add_literal(1);
auto sum1 = p2.add_instruction(migraphx::op::add{}, one, one);
auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum1);
p2.add_instruction(pass_op{}, sum3);
auto* mm2 = p2.get_main_module();
auto one = mm2->add_literal(1);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, one);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum1, sum1);
mm2->add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
}
......@@ -87,24 +94,26 @@ TEST_CASE(cse_test4)
{
migraphx::program p1;
{
auto one = p1.add_literal(1);
auto two = p1.add_literal(1);
auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p1.add_instruction(migraphx::op::add{}, two, one);
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, one);
auto sum4 = p1.add_instruction(migraphx::op::add{}, sum2, two);
auto sum5 = p1.add_instruction(migraphx::op::add{}, sum4, sum3);
p1.add_instruction(pass_op{}, sum5);
auto* mm1 = p1.get_main_module();
auto one = mm1->add_literal(1);
auto two = mm1->add_literal(1);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, one, two);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, two, one);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, one);
auto sum4 = mm1->add_instruction(migraphx::op::add{}, sum2, two);
auto sum5 = mm1->add_instruction(migraphx::op::add{}, sum4, sum3);
mm1->add_instruction(pass_op{}, sum5);
}
run_pass(p1);
migraphx::program p2;
{
auto one = p2.add_literal(1);
auto sum1 = p2.add_instruction(migraphx::op::add{}, one, one);
auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, one);
auto sum5 = p2.add_instruction(migraphx::op::add{}, sum3, sum3);
p2.add_instruction(pass_op{}, sum5);
auto* mm2 = p2.get_main_module();
auto one = mm2->add_literal(1);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, one);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum1, one);
auto sum5 = mm2->add_instruction(migraphx::op::add{}, sum3, sum3);
mm2->add_instruction(pass_op{}, sum5);
}
EXPECT(p1 == p2);
}
......@@ -113,30 +122,32 @@ TEST_CASE(cse_test_literal)
{
migraphx::program p1;
{
auto six1 = p1.add_literal(6);
auto zero1 = p1.add_literal(0);
auto six2 = p1.add_literal(6);
auto zero2 = p1.add_literal(0);
auto six3 = p1.add_literal(6);
auto zero3 = p1.add_literal(0);
auto* mm1 = p1.get_main_module();
auto six1 = mm1->add_literal(6);
auto zero1 = mm1->add_literal(0);
auto six2 = mm1->add_literal(6);
auto zero2 = mm1->add_literal(0);
auto six3 = mm1->add_literal(6);
auto zero3 = mm1->add_literal(0);
auto sum1 = p1.add_instruction(migraphx::op::add{}, six1, zero1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, six2, zero2);
auto sum3 = p1.add_instruction(migraphx::op::add{}, six3, zero3);
auto sum4 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum5 = p1.add_instruction(migraphx::op::add{}, sum3, sum4);
p1.add_instruction(pass_op{}, sum5);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, six1, zero1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, six2, zero2);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, six3, zero3);
auto sum4 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum5 = mm1->add_instruction(migraphx::op::add{}, sum3, sum4);
mm1->add_instruction(pass_op{}, sum5);
}
run_pass(p1);
migraphx::program p2;
{
auto six = p2.add_literal(6);
auto zero = p2.add_literal(0);
auto sum1 = p2.add_instruction(migraphx::op::add{}, six, zero);
auto sum2 = p2.add_instruction(migraphx::op::add{}, sum1, sum1);
auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum2);
p2.add_instruction(pass_op{}, sum3);
auto* mm2 = p2.get_main_module();
auto six = mm2->add_literal(6);
auto zero = mm2->add_literal(0);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, six, zero);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, sum1, sum1);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum1, sum2);
mm2->add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
}
......
......@@ -47,7 +47,7 @@ struct concat_test_optimization
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p,
migraphx::run_passes(*p.get_main_module(),
{migraphx::eliminate_concat{concat_test_optimization{}},
migraphx::dead_code_elimination{}});
}
......@@ -106,23 +106,27 @@ TEST_CASE(simple)
{
auto create_test_program = [] {
migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(1)});
auto p1 = p.add_instruction(simple_op{}, a1);
auto a2 = p.add_instruction(allocate{create_shape(1)});
auto p2 = p.add_instruction(simple_op{}, a2);
auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(allocate{create_shape(1)});
auto p1 = mm->add_instruction(simple_op{}, a1);
auto a2 = mm->add_instruction(allocate{create_shape(1)});
auto p2 = mm->add_instruction(simple_op{}, a2);
std::size_t axis = 0;
auto a3 = p.add_instruction(allocate{create_shape(2)});
p.add_instruction(concat(axis), p1, p2, a3);
auto a3 = mm->add_instruction(allocate{create_shape(2)});
mm->add_instruction(concat(axis), p1, p2, a3);
return p;
};
auto create_control_program = [] {
migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(2)});
auto l1 = p.add_instruction(load{create_shape(1), 0}, a1);
auto p1 = p.add_instruction(simple_op{}, l1);
auto l2 = p.add_instruction(load{create_shape(1), 4}, a1);
auto p2 = p.add_instruction(simple_op{}, l2);
p.add_instruction(identity{}, a1, p1, p2);
auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(allocate{create_shape(2)});
auto l1 = mm->add_instruction(load{create_shape(1), 0}, a1);
auto p1 = mm->add_instruction(simple_op{}, l1);
auto l2 = mm->add_instruction(load{create_shape(1), 4}, a1);
auto p2 = mm->add_instruction(simple_op{}, l2);
mm->add_instruction(identity{}, a1, p1, p2);
return p;
};
......@@ -137,13 +141,15 @@ TEST_CASE(negative_axis1)
{
auto create_test_program = [] {
migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(2, 2)});
auto p1 = p.add_instruction(simple_op{}, a1);
auto a2 = p.add_instruction(allocate{create_shape(2, 2)});
auto p2 = p.add_instruction(simple_op{}, a2);
auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(allocate{create_shape(2, 2)});
auto p1 = mm->add_instruction(simple_op{}, a1);
auto a2 = mm->add_instruction(allocate{create_shape(2, 2)});
auto p2 = mm->add_instruction(simple_op{}, a2);
std::size_t axis = -1;
auto a3 = p.add_instruction(allocate{create_shape(4, 2)});
p.add_instruction(concat(axis), p1, p2, a3);
auto a3 = mm->add_instruction(allocate{create_shape(4, 2)});
mm->add_instruction(concat(axis), p1, p2, a3);
return p;
};
auto create_control_program = create_test_program;
......@@ -159,23 +165,27 @@ TEST_CASE(negative_axis2)
{
auto create_test_program = [] {
migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(2, 2)});
auto p1 = p.add_instruction(simple_op{}, a1);
auto a2 = p.add_instruction(allocate{create_shape(2, 2)});
auto p2 = p.add_instruction(simple_op{}, a2);
auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(allocate{create_shape(2, 2)});
auto p1 = mm->add_instruction(simple_op{}, a1);
auto a2 = mm->add_instruction(allocate{create_shape(2, 2)});
auto p2 = mm->add_instruction(simple_op{}, a2);
std::size_t axis = -2;
auto a3 = p.add_instruction(allocate{create_shape(4, 2)});
p.add_instruction(concat(axis), p1, p2, a3);
auto a3 = mm->add_instruction(allocate{create_shape(4, 2)});
mm->add_instruction(concat(axis), p1, p2, a3);
return p;
};
auto create_control_program = [] {
migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(4, 2)});
auto l1 = p.add_instruction(load{create_shape(2, 2), 0}, a1);
auto p1 = p.add_instruction(simple_op{}, l1);
auto l2 = p.add_instruction(load{create_shape(2, 2), 16}, a1);
auto p2 = p.add_instruction(simple_op{}, l2);
p.add_instruction(identity{}, a1, p1, p2);
auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(allocate{create_shape(4, 2)});
auto l1 = mm->add_instruction(load{create_shape(2, 2), 0}, a1);
auto p1 = mm->add_instruction(simple_op{}, l1);
auto l2 = mm->add_instruction(load{create_shape(2, 2), 16}, a1);
auto p2 = mm->add_instruction(simple_op{}, l2);
mm->add_instruction(identity{}, a1, p1, p2);
return p;
};
......@@ -190,23 +200,27 @@ TEST_CASE(negative_axis3)
{
auto create_test_program = [] {
migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(1, 2, 2)});
auto p1 = p.add_instruction(simple_op{}, a1);
auto a2 = p.add_instruction(allocate{create_shape(1, 2, 2)});
auto p2 = p.add_instruction(simple_op{}, a2);
auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(allocate{create_shape(1, 2, 2)});
auto p1 = mm->add_instruction(simple_op{}, a1);
auto a2 = mm->add_instruction(allocate{create_shape(1, 2, 2)});
auto p2 = mm->add_instruction(simple_op{}, a2);
std::size_t axis = -2;
auto a3 = p.add_instruction(allocate{create_shape(1, 4, 2)});
p.add_instruction(concat(axis), p1, p2, a3);
auto a3 = mm->add_instruction(allocate{create_shape(1, 4, 2)});
mm->add_instruction(concat(axis), p1, p2, a3);
return p;
};
auto create_control_program = [] {
migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(1, 4, 2)});
auto l1 = p.add_instruction(load{create_shape(1, 2, 2), 0}, a1);
auto p1 = p.add_instruction(simple_op{}, l1);
auto l2 = p.add_instruction(load{create_shape(1, 2, 2), 16}, a1);
auto p2 = p.add_instruction(simple_op{}, l2);
p.add_instruction(identity{}, a1, p1, p2);
auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(allocate{create_shape(1, 4, 2)});
auto l1 = mm->add_instruction(load{create_shape(1, 2, 2), 0}, a1);
auto p1 = mm->add_instruction(simple_op{}, l1);
auto l2 = mm->add_instruction(load{create_shape(1, 2, 2), 16}, a1);
auto p2 = mm->add_instruction(simple_op{}, l2);
mm->add_instruction(identity{}, a1, p1, p2);
return p;
};
......@@ -221,23 +235,27 @@ TEST_CASE(reversed)
{
auto create_test_program = [] {
migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(1)});
auto p1 = p.add_instruction(simple_op{}, a1);
auto a2 = p.add_instruction(allocate{create_shape(1)});
auto p2 = p.add_instruction(simple_op{}, a2);
auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(allocate{create_shape(1)});
auto p1 = mm->add_instruction(simple_op{}, a1);
auto a2 = mm->add_instruction(allocate{create_shape(1)});
auto p2 = mm->add_instruction(simple_op{}, a2);
std::size_t axis = 0;
auto a3 = p.add_instruction(allocate{create_shape(2)});
p.add_instruction(concat(axis), p2, p1, a3);
auto a3 = mm->add_instruction(allocate{create_shape(2)});
mm->add_instruction(concat(axis), p2, p1, a3);
return p;
};
auto create_control_program = [] {
migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(2)});
auto l1 = p.add_instruction(load{create_shape(1), 4}, a1);
auto p1 = p.add_instruction(simple_op{}, l1);
auto l2 = p.add_instruction(load{create_shape(1), 0}, a1);
auto p2 = p.add_instruction(simple_op{}, l2);
p.add_instruction(identity{}, a1, p2, p1);
auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(allocate{create_shape(2)});
auto l1 = mm->add_instruction(load{create_shape(1), 4}, a1);
auto p1 = mm->add_instruction(simple_op{}, l1);
auto l2 = mm->add_instruction(load{create_shape(1), 0}, a1);
auto p2 = mm->add_instruction(simple_op{}, l2);
mm->add_instruction(identity{}, a1, p2, p1);
return p;
};
......@@ -251,38 +269,42 @@ TEST_CASE(reversed)
TEST_CASE(nested)
{
auto concat_test_program = [](auto& p) {
auto a1 = p.add_instruction(allocate{create_shape(1)});
auto p1 = p.add_instruction(simple_op{}, a1);
auto a2 = p.add_instruction(allocate{create_shape(1)});
auto p2 = p.add_instruction(simple_op{}, a2);
auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(allocate{create_shape(1)});
auto p1 = mm->add_instruction(simple_op{}, a1);
auto a2 = mm->add_instruction(allocate{create_shape(1)});
auto p2 = mm->add_instruction(simple_op{}, a2);
std::size_t axis = 0;
auto a3 = p.add_instruction(allocate{create_shape(2)});
return p.add_instruction(concat(axis), p1, p2, a3);
auto a3 = mm->add_instruction(allocate{create_shape(2)});
return mm->add_instruction(concat(axis), p1, p2, a3);
};
auto create_test_program = [&] {
migraphx::program p;
auto* mm = p.get_main_module();
auto concat1 = concat_test_program(p);
auto concat2 = concat_test_program(p);
std::size_t axis = 0;
auto a1 = p.add_instruction(allocate{create_shape(4)});
p.add_instruction(concat(axis), concat1, concat2, a1);
auto a1 = mm->add_instruction(allocate{create_shape(4)});
mm->add_instruction(concat(axis), concat1, concat2, a1);
return p;
};
auto concat_control_program = [](auto& p, auto a1) {
auto l1 = p.add_instruction(load{create_shape(1), 0}, a1);
auto p1 = p.add_instruction(simple_op{}, l1);
auto l2 = p.add_instruction(load{create_shape(1), 4}, a1);
auto p2 = p.add_instruction(simple_op{}, l2);
return p.add_instruction(identity{}, a1, p1, p2);
auto* mm = p.get_main_module();
auto l1 = mm->add_instruction(load{create_shape(1), 0}, a1);
auto p1 = mm->add_instruction(simple_op{}, l1);
auto l2 = mm->add_instruction(load{create_shape(1), 4}, a1);
auto p2 = mm->add_instruction(simple_op{}, l2);
return mm->add_instruction(identity{}, a1, p1, p2);
};
auto create_control_program = [&] {
migraphx::program p;
auto a1 = p.add_instruction(allocate{create_shape(4)});
auto l1 = p.add_instruction(load{create_shape(2), 0}, a1);
auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(allocate{create_shape(4)});
auto l1 = mm->add_instruction(load{create_shape(2), 0}, a1);
auto concat1 = concat_control_program(p, l1);
auto l2 = p.add_instruction(load{create_shape(2), 8}, a1);
auto l2 = mm->add_instruction(load{create_shape(2), 8}, a1);
auto concat2 = concat_control_program(p, l2);
p.add_instruction(identity{}, a1, concat1, concat2);
mm->add_instruction(identity{}, a1, concat1, concat2);
return p;
};
......@@ -297,35 +319,37 @@ TEST_CASE(basic)
{
auto create_test_program = [] {
migraphx::program p;
auto a1 =
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}});
auto p1 = p.add_instruction(simple_op{}, a1);
auto a2 =
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}});
auto p2 = p.add_instruction(simple_op{}, a2);
auto a3 =
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}});
auto p3 = p.add_instruction(simple_op{}, a3);
auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}});
auto p1 = mm->add_instruction(simple_op{}, a1);
auto a2 = mm->add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}});
auto p2 = mm->add_instruction(simple_op{}, a2);
auto a3 = mm->add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}});
auto p3 = mm->add_instruction(simple_op{}, a3);
std::size_t axis = 1;
auto a4 = p.add_instruction(
auto a4 = mm->add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}});
p.add_instruction(concat(axis), p1, p2, p3, a4);
mm->add_instruction(concat(axis), p1, p2, p3, a4);
return p;
};
auto create_control_program = [] {
migraphx::program p;
auto a1 = p.add_instruction(
auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}});
auto l1 = p.add_instruction(
auto l1 = mm->add_instruction(
load{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}, 0}, {a1});
auto p1 = p.add_instruction(simple_op{}, l1);
auto l2 = p.add_instruction(
auto p1 = mm->add_instruction(simple_op{}, l1);
auto l2 = mm->add_instruction(
load{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}, 512}, {a1});
auto p2 = p.add_instruction(simple_op{}, l2);
auto l3 = p.add_instruction(
auto p2 = mm->add_instruction(simple_op{}, l2);
auto l3 = mm->add_instruction(
load{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}, 1280}, {a1});
auto p3 = p.add_instruction(simple_op{}, l3);
p.add_instruction(identity{}, {a1, p1, p2, p3});
auto p3 = mm->add_instruction(simple_op{}, l3);
mm->add_instruction(identity{}, {a1, p1, p2, p3});
return p;
};
......@@ -340,36 +364,38 @@ TEST_CASE(wont_work)
{
auto create_test_program = [] {
migraphx::program p;
auto a1 =
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
auto p1 = p.add_instruction(simple_op{}, a1);
auto a2 =
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
auto p2 = p.add_instruction(simple_op{}, a2);
auto a3 =
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
auto p3 = p.add_instruction(simple_op{}, a3);
auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
auto p1 = mm->add_instruction(simple_op{}, a1);
auto a2 = mm->add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
auto p2 = mm->add_instruction(simple_op{}, a2);
auto a3 = mm->add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
auto p3 = mm->add_instruction(simple_op{}, a3);
std::size_t axis = 1;
auto a4 = p.add_instruction(
auto a4 = mm->add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
p.add_instruction(concat(axis), p1, p2, p3, a4);
mm->add_instruction(concat(axis), p1, p2, p3, a4);
return p;
};
auto create_control_program = [] {
migraphx::program p;
auto a1 =
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
auto p1 = p.add_instruction(simple_op{}, a1);
auto a2 =
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
auto p2 = p.add_instruction(simple_op{}, a2);
auto a3 =
p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
auto p3 = p.add_instruction(simple_op{}, a3);
auto* mm = p.get_main_module();
auto a1 = mm->add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
auto p1 = mm->add_instruction(simple_op{}, a1);
auto a2 = mm->add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
auto p2 = mm->add_instruction(simple_op{}, a2);
auto a3 = mm->add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
auto p3 = mm->add_instruction(simple_op{}, a3);
std::size_t axis = 1;
auto a4 = p.add_instruction(
auto a4 = mm->add_instruction(
allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
p.add_instruction(concat(axis), p1, p2, p3, a4);
mm->add_instruction(concat(axis), p1, p2, p3, a4);
return p;
};
......
......@@ -12,16 +12,19 @@
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::eliminate_contiguous{}, migraphx::dead_code_elimination{}});
migraphx::run_passes(*p.get_main_module(),
{migraphx::eliminate_contiguous{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(standard_op)
{
migraphx::program p;
auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_standard_op{}, c);
auto* mm = p.get_main_module();
auto l = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, t);
mm->add_instruction(pass_standard_op{}, c);
auto count = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count);
......@@ -30,10 +33,12 @@ TEST_CASE(standard_op)
TEST_CASE(standard_op_const)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_standard_op{}, c);
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, t);
mm->add_instruction(pass_standard_op{}, c);
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == 2);
}
......@@ -41,10 +46,12 @@ TEST_CASE(standard_op_const)
TEST_CASE(non_standard_op)
{
migraphx::program p;
auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_op{}, c);
auto* mm = p.get_main_module();
auto l = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, t);
mm->add_instruction(pass_op{}, c);
auto count = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count);
......@@ -53,10 +60,12 @@ TEST_CASE(non_standard_op)
TEST_CASE(non_standard_op_const)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_op{}, c);
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, t);
mm->add_instruction(pass_op{}, c);
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == 2);
}
......@@ -64,11 +73,13 @@ TEST_CASE(non_standard_op_const)
TEST_CASE(transpose_gemm)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto ic = p.add_instruction(migraphx::op::identity{}, c);
p.add_instruction(migraphx::op::dot{}, ic, l);
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, t);
auto ic = mm->add_instruction(migraphx::op::identity{}, c);
mm->add_instruction(migraphx::op::dot{}, ic, l);
auto count = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == (count - 1));
......@@ -77,11 +88,13 @@ TEST_CASE(transpose_gemm)
TEST_CASE(transpose_standard_op)
{
migraphx::program p;
auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto sn = p.add_instruction(migraphx::op::sin{}, c);
p.add_instruction(pass_standard_op{}, sn);
auto* mm = p.get_main_module();
auto l = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, t);
auto sn = mm->add_instruction(migraphx::op::sin{}, c);
mm->add_instruction(pass_standard_op{}, sn);
auto count = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count);
......@@ -90,11 +103,13 @@ TEST_CASE(transpose_standard_op)
TEST_CASE(transpose_standard_op_const)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto sn = p.add_instruction(migraphx::op::sin{}, c);
p.add_instruction(pass_standard_op{}, sn);
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, t);
auto sn = mm->add_instruction(migraphx::op::sin{}, c);
mm->add_instruction(pass_standard_op{}, sn);
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == 3);
}
......@@ -102,11 +117,13 @@ TEST_CASE(transpose_standard_op_const)
TEST_CASE(no_packed_unary_op)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto sn = p.add_instruction(migraphx::op::sin{}, c);
p.add_instruction(pass_standard_op{}, sn);
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
auto t = mm->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, t);
auto sn = mm->add_instruction(migraphx::op::sin{}, c);
mm->add_instruction(pass_standard_op{}, sn);
auto count = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count - 1);
......@@ -115,10 +132,12 @@ TEST_CASE(no_packed_unary_op)
TEST_CASE(non_standard_return_input)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto tl = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, tl);
p.add_return({c});
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
auto tl = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, tl);
mm->add_return({c});
auto count = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count);
......
......@@ -6,17 +6,22 @@
#include <migraphx/op/identity.hpp>
#include <test.hpp>
void run_pass(migraphx::program& p) { migraphx::run_passes(p, {migraphx::eliminate_identity{}}); }
void run_pass(migraphx::program& p)
{
migraphx::run_passes(*p.get_main_module(), {migraphx::eliminate_identity{}});
}
TEST_CASE(simple_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto one_identity = p.add_instruction(migraphx::op::identity{}, one);
auto two = p.add_literal(2);
auto two_identity = p.add_instruction(migraphx::op::identity{}, two);
p.add_instruction(sum_op{}, one_identity, two_identity);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto one_identity = mm->add_instruction(migraphx::op::identity{}, one);
auto two = mm->add_literal(2);
auto two_identity = mm->add_instruction(migraphx::op::identity{}, two);
mm->add_instruction(sum_op{}, one_identity, two_identity);
run_pass(p);
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity";
......@@ -29,10 +34,12 @@ TEST_CASE(simple_test_end)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto ans = p.add_instruction(sum_op{}, one, two);
p.add_instruction(migraphx::op::identity{}, ans);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto ans = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::op::identity{}, ans);
run_pass(p);
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity";
......@@ -45,12 +52,14 @@ TEST_CASE(simple_test_end_dependency)
{
migraphx::program p;
auto one = p.add_literal(1.0);
auto two = p.add_literal(2.0);
auto three = p.add_literal(3.0);
auto ans = p.add_instruction(sum_op{}, one, two);
p.add_instruction(sum_op{}, ans, three);
p.add_instruction(migraphx::op::identity{}, ans);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1.0);
auto two = mm->add_literal(2.0);
auto three = mm->add_literal(3.0);
auto ans = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(sum_op{}, ans, three);
mm->add_instruction(migraphx::op::identity{}, ans);
run_pass(p);
EXPECT(std::any_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity";
......
......@@ -8,7 +8,8 @@
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::eliminate_pad{}, migraphx::dead_code_elimination{}});
migraphx::run_passes(*p.get_main_module(),
{migraphx::eliminate_pad{}, migraphx::dead_code_elimination{}});
}
migraphx::instruction_ref
......@@ -16,10 +17,10 @@ create_im2col(migraphx::instruction_ref& l_img, size_t channels, migraphx::progr
{
size_t f[2] = {1, 1};
std::vector<int32_t> weights(channels * f[0] * f[1]);
auto* mm = p.get_main_module();
migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_weights = p.add_literal(migraphx::literal{s_weights, weights});
return p.add_instruction(migraphx::op::im2col{}, l_img, l_weights);
auto l_weights = mm->add_literal(migraphx::literal{s_weights, weights});
return mm->add_instruction(migraphx::op::im2col{}, l_img, l_weights);
}
migraphx::instruction_ref
......@@ -30,30 +31,30 @@ create_conv(migraphx::instruction_ref& l_img,
{
migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}};
std::vector<int32_t> weights(4 * channels * 3 * 3);
auto l_weights = p.add_literal(migraphx::literal{s_weights, weights});
auto* mm = p.get_main_module();
auto l_weights = mm->add_literal(migraphx::literal{s_weights, weights});
migraphx::op::convolution op;
op.padding_mode = padding_mode;
return p.add_instruction(op, l_img, l_weights);
return mm->add_instruction(op, l_img, l_weights);
}
TEST_CASE(rewrite_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
size_t img_dim[2] = {2, 2};
size_t channels = 1;
std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]);
std::iota(input.begin(), input.end(), 0);
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = p.add_literal(migraphx::literal{s_img, input});
auto padded_img = p.add_instruction(migraphx::op::pad{{0, 0, 1, 1, 0, 0, 1, 1}}, l_img);
auto l_img = mm->add_literal(migraphx::literal{s_img, input});
auto padded_img = mm->add_instruction(migraphx::op::pad{{0, 0, 1, 1, 0, 0, 1, 1}}, l_img);
auto l0 = create_im2col(padded_img, channels, p);
auto l1 = create_conv(padded_img, channels, p);
auto l2 = p.add_instruction(migraphx::op::pooling{"max"}, padded_img);
p.add_instruction(migraphx::op::identity{}, l0, l1, l2);
auto l2 = mm->add_instruction(migraphx::op::pooling{"max"}, padded_img);
mm->add_instruction(migraphx::op::identity{}, l0, l1, l2);
run_pass(p);
EXPECT(std::none_of(
......@@ -63,14 +64,16 @@ TEST_CASE(rewrite_test)
TEST_CASE(rewrite_test_asymmetric)
{
migraphx::program p;
auto* mm = p.get_main_module();
size_t img_dim[2] = {2, 2};
size_t channels = 1;
std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]);
std::iota(input.begin(), input.end(), 0);
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = p.add_literal(migraphx::literal{s_img, input});
auto padded_img = p.add_instruction(migraphx::op::pad{{0, 0, 0, 0, 0, 0, 2, 2}}, l_img);
auto l_img = mm->add_literal(migraphx::literal{s_img, input});
auto padded_img = mm->add_instruction(migraphx::op::pad{{0, 0, 0, 0, 0, 0, 2, 2}}, l_img);
create_im2col(padded_img, channels, p);
......
......@@ -71,7 +71,7 @@ struct reverse_pass
{
std::string name() const { return "reverse_pass"; }
void apply(migraphx::program& p) const { std::reverse(p.begin(), p.end()); }
void apply(migraphx::module& p) const { std::reverse(p.begin(), p.end()); }
};
struct reverse_target
......@@ -89,7 +89,7 @@ struct invert_pass
{
std::string name() const { return "invert_pass"; }
void apply(migraphx::program& p) const
void apply(migraphx::module& p) const
{
for(auto ins : migraphx::iterator_for(p))
{
......@@ -130,10 +130,10 @@ struct double_invert_target
TEST_CASE(literal_test1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, one, two);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4});
......@@ -142,11 +142,11 @@ TEST_CASE(literal_test1)
TEST_CASE(literal_test2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
p.add_instruction(sum_op{}, sum1, two);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(sum_op{}, sum1, two);
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{5});
......@@ -156,10 +156,10 @@ TEST_CASE(literal_test2)
TEST_CASE(print_test)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int32_type});
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, x, two);
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::int32_type});
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, x, two);
std::stringstream ss;
ss << p;
......@@ -170,11 +170,11 @@ TEST_CASE(print_test)
TEST_CASE(param_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::int32_type});
auto y = mm->add_parameter("y", {migraphx::shape::int32_type});
auto x = p.add_parameter("x", {migraphx::shape::int32_type});
auto y = p.add_parameter("y", {migraphx::shape::int32_type});
p.add_instruction(sum_op{}, x, y);
mm->add_instruction(sum_op{}, x, y);
auto result = p.eval({{"x", migraphx::literal{1}.get_argument()},
{"y", migraphx::literal{2}.get_argument()}})
.back();
......@@ -185,11 +185,11 @@ TEST_CASE(param_test)
TEST_CASE(param_error_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::int32_type});
auto y = mm->add_parameter("y", {migraphx::shape::int32_type});
auto x = p.add_parameter("x", {migraphx::shape::int32_type});
auto y = p.add_parameter("y", {migraphx::shape::int32_type});
p.add_instruction(sum_op{}, x, y);
mm->add_instruction(sum_op{}, x, y);
EXPECT(test::throws<migraphx::exception>(
[&] {
p.eval({{"x", migraphx::literal{1}.get_argument()}});
......@@ -200,11 +200,11 @@ TEST_CASE(param_error_test)
TEST_CASE(param_error_shape_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1, 1}});
auto y = mm->add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 1}});
auto y = p.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
p.add_instruction(sum_op{}, x, y);
mm->add_instruction(sum_op{}, x, y);
EXPECT(test::throws<migraphx::exception>(
[&] {
p.eval({
......@@ -218,10 +218,11 @@ TEST_CASE(param_error_shape_test)
TEST_CASE(get_param1)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {1, 2}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
p.add_instruction(sum_op{}, x, y);
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(sum_op{}, x, y);
EXPECT(bool{p.get_parameter("x") == x});
EXPECT(bool{p.get_parameter("y") == y});
EXPECT(bool{p.get_parameter("nonexistent") == p.end()});
......@@ -230,19 +231,21 @@ TEST_CASE(get_param1)
TEST_CASE(get_param2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, one, two);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
EXPECT(bool{p.get_parameter("nonexistent") == p.end()});
}
TEST_CASE(get_param_shapes)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {1, 2}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
p.add_instruction(sum_op{}, x, y);
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(sum_op{}, x, y);
auto m = p.get_parameter_shapes();
EXPECT(m.count("nonexistent") == 0);
EXPECT(m.at("x") == s);
......@@ -252,11 +255,11 @@ TEST_CASE(get_param_shapes)
TEST_CASE(replace_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.replace_instruction(sum, minus_op{}, two, one);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->replace_instruction(sum, minus_op{}, two, one);
EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}).back();
......@@ -267,12 +270,12 @@ TEST_CASE(replace_test)
TEST_CASE(replace_ins_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
auto minus = p.add_instruction(minus_op{}, two, one);
p.replace_instruction(sum, minus);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto minus = mm->add_instruction(minus_op{}, two, one);
mm->replace_instruction(sum, minus);
EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}).back();
......@@ -283,13 +286,13 @@ TEST_CASE(replace_ins_test)
TEST_CASE(replace_ins_test2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
auto minus = p.add_instruction(minus_op{}, two, one);
p.add_instruction(pass_op{}, minus);
p.replace_instruction(two, sum);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto minus = mm->add_instruction(minus_op{}, two, one);
mm->add_instruction(pass_op{}, minus);
mm->replace_instruction(two, sum);
EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}).back();
......@@ -300,10 +303,10 @@ TEST_CASE(replace_ins_test2)
TEST_CASE(replace_op_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, two, one);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, two, one);
sum->replace(minus_op{});
EXPECT(bool{p.validate() == p.end()});
......@@ -315,24 +318,24 @@ TEST_CASE(replace_op_test)
TEST_CASE(replace_op_recompute_shape_throw)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
EXPECT(test::throws<migraphx::exception>([&] { sum->replace(unary_pass_op{}); }));
}
TEST_CASE(insert_replace_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
p.add_instruction(sum_op{}, sum1, two);
auto sum0 = p.insert_instruction(sum1, sum_op{}, two, two);
p.replace_instruction(sum1, minus_op{}, sum0, two);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(sum_op{}, sum1, two);
auto sum0 = mm->insert_instruction(sum1, sum_op{}, two, two);
mm->replace_instruction(sum1, minus_op{}, sum0, two);
EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}).back();
......@@ -343,12 +346,12 @@ TEST_CASE(insert_replace_test)
TEST_CASE(remove_test1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
auto removed = p.add_instruction(minus_op{}, sum, one);
p.remove_instruction(removed);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto removed = mm->add_instruction(minus_op{}, sum, one);
mm->remove_instruction(removed);
EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}).back();
......@@ -359,12 +362,12 @@ TEST_CASE(remove_test1)
TEST_CASE(remove_test2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto removed = p.add_instruction(minus_op{}, two, one);
p.add_instruction(sum_op{}, one, two);
p.remove_instruction(removed);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto removed = mm->add_instruction(minus_op{}, two, one);
mm->add_instruction(sum_op{}, one, two);
mm->remove_instruction(removed);
EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}).back();
......@@ -375,10 +378,10 @@ TEST_CASE(remove_test2)
TEST_CASE(target_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, one, two);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
p.compile(id_target{});
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3});
......@@ -388,10 +391,10 @@ TEST_CASE(target_test)
TEST_CASE(invert_target_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, two, one);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, two, one);
p.compile(invert_target{});
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{1});
......@@ -401,10 +404,10 @@ TEST_CASE(invert_target_test)
TEST_CASE(double_invert_target_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, two, one);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, two, one);
p.compile(double_invert_target{});
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3});
......@@ -414,10 +417,10 @@ TEST_CASE(double_invert_target_test)
TEST_CASE(reverse_target_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, one, two);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
EXPECT(test::throws<migraphx::exception>([&] { p.compile(reverse_target{}); }));
}
......@@ -426,11 +429,12 @@ TEST_CASE(reverse_target_test)
TEST_CASE(eval_context1)
{
migraphx::program p;
auto* mm = p.get_main_module();
id_target t{};
EXPECT(is_shared(t.ctx, t.get_context()));
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, one, two);
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
p.compile(t);
EXPECT(is_shared(t.ctx, p.get_context()));
p.eval({}).back();
......@@ -440,11 +444,12 @@ TEST_CASE(eval_context1)
TEST_CASE(eval_context2)
{
migraphx::program p;
auto* mm = p.get_main_module();
id_target t{};
EXPECT(is_shared(t.ctx, t.get_context()));
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(id_ctx_op{}, one, two);
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(id_ctx_op{}, one, two);
p.compile(t);
EXPECT(is_shared(t.ctx, p.get_context()));
p.eval({}).back();
......@@ -455,11 +460,12 @@ TEST_CASE(eval_context2)
TEST_CASE(eval_context3)
{
migraphx::program p;
auto* mm = p.get_main_module();
id_target t{};
EXPECT(is_shared(t.ctx, t.get_context()));
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(id_ctx_final_op{}, one, two);
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(id_ctx_final_op{}, one, two);
p.compile(t);
// Finalizer will modify the context
EXPECT(not is_shared(t.ctx, p.get_context()));
......@@ -495,11 +501,13 @@ std::string capture_output(F f)
TEST_CASE(debug_print_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
std::vector<migraphx::instruction_ref> onev = {one};
migraphx::program p2;
auto one2 = p2.add_literal(1);
auto* mm2 = p2.get_main_module();
auto one2 = mm2->add_literal(1);
auto program_out = migraphx::trim(capture_output([&] { p.debug_print(); }));
auto ins_out = migraphx::trim(capture_output([&] { p.debug_print(one); }));
......
......@@ -18,7 +18,7 @@
void run_lowering(migraphx::program& p)
{
auto ctx = migraphx::gpu::context{};
migraphx::run_passes(p,
migraphx::run_passes(*p.get_main_module(),
{migraphx::auto_contiguous{},
migraphx::gpu::lowering{&ctx, false},
migraphx::dead_code_elimination{},
......@@ -30,12 +30,13 @@ TEST_CASE(tanh_shape)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto x = p.add_parameter("x", s);
auto tx = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto txh = p.add_instruction(migraphx::op::tanh{}, tx);
auto sum = p.add_instruction(migraphx::op::add{}, txh, txh);
p.add_instruction(migraphx::op::contiguous{}, sum);
auto x = mm->add_parameter("x", s);
auto tx = mm->add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto txh = mm->add_instruction(migraphx::op::tanh{}, tx);
auto sum = mm->add_instruction(migraphx::op::add{}, txh, txh);
mm->add_instruction(migraphx::op::contiguous{}, sum);
return p;
};
......@@ -59,7 +60,7 @@ TEST_CASE(tanh_shape)
}
EXPECT(p1 != p2);
migraphx::run_passes(p2,
migraphx::run_passes(*p2.get_main_module(),
{migraphx::gpu::adjust_allocation{}, migraphx::dead_code_elimination{}});
EXPECT(p1 == p2);
}
......
......@@ -12,6 +12,7 @@
migraphx::program create_gelu()
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data0 = {0.044715};
std::vector<float> data1 = {0.797885};
std::vector<float> data2 = {3};
......@@ -20,25 +21,25 @@ migraphx::program create_gelu()
std::vector<size_t> x_dims{1, 1, 5};
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, x_dims});
auto const_val = p.add_literal(migraphx::literal{s0, data0});
auto sqrt_2_pi = p.add_literal(migraphx::literal{s0, data1});
auto three_val = p.add_literal(migraphx::literal{s0, data2});
auto half_val = p.add_literal(migraphx::literal{s0, data3});
auto mbcast_3 = p.add_instruction(migraphx::op::multibroadcast{x_dims}, three_val);
auto pow_op = p.add_instruction(migraphx::op::pow{}, x, mbcast_3);
auto mbcast_const = p.add_instruction(migraphx::op::multibroadcast{x_dims}, const_val);
auto mul_const = p.add_instruction(migraphx::op::mul{}, mbcast_const, pow_op);
auto add_x = p.add_instruction(migraphx::op::add{}, x, mul_const);
auto mbcast_sqrt_2_pi = p.add_instruction(migraphx::op::multibroadcast{x_dims}, sqrt_2_pi);
auto mul_add_x = p.add_instruction(migraphx::op::mul{}, mbcast_sqrt_2_pi, add_x);
auto tanh_op = p.add_instruction(migraphx::op::tanh{}, mul_add_x);
auto mbcast_half = p.add_instruction(migraphx::op::multibroadcast{x_dims}, half_val);
auto mul_half = p.add_instruction(migraphx::op::mul{}, mbcast_half, tanh_op);
auto add_mul_half = p.add_instruction(migraphx::op::add{}, mul_half, mbcast_half);
auto mul_x = p.add_instruction(migraphx::op::mul{}, x, add_mul_half);
p.add_return({mul_x});
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, x_dims});
auto const_val = mm->add_literal(migraphx::literal{s0, data0});
auto sqrt_2_pi = mm->add_literal(migraphx::literal{s0, data1});
auto three_val = mm->add_literal(migraphx::literal{s0, data2});
auto half_val = mm->add_literal(migraphx::literal{s0, data3});
auto mbcast_3 = mm->add_instruction(migraphx::op::multibroadcast{x_dims}, three_val);
auto pow_op = mm->add_instruction(migraphx::op::pow{}, x, mbcast_3);
auto mbcast_const = mm->add_instruction(migraphx::op::multibroadcast{x_dims}, const_val);
auto mul_const = mm->add_instruction(migraphx::op::mul{}, mbcast_const, pow_op);
auto add_x = mm->add_instruction(migraphx::op::add{}, x, mul_const);
auto mbcast_sqrt_2_pi = mm->add_instruction(migraphx::op::multibroadcast{x_dims}, sqrt_2_pi);
auto mul_add_x = mm->add_instruction(migraphx::op::mul{}, mbcast_sqrt_2_pi, add_x);
auto tanh_op = mm->add_instruction(migraphx::op::tanh{}, mul_add_x);
auto mbcast_half = mm->add_instruction(migraphx::op::multibroadcast{x_dims}, half_val);
auto mul_half = mm->add_instruction(migraphx::op::mul{}, mbcast_half, tanh_op);
auto add_mul_half = mm->add_instruction(migraphx::op::add{}, mul_half, mbcast_half);
auto mul_x = mm->add_instruction(migraphx::op::mul{}, x, add_mul_half);
mm->add_return({mul_x});
return p;
}
......
......@@ -9,8 +9,9 @@
void gpu_literal_test()
{
migraphx::program p;
auto* mm = p.get_main_module();
auto lit = generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
p.add_literal(lit);
mm->add_literal(lit);
p.compile(migraphx::gpu::target{});
auto scratch = p.get_parameter("scratch");
if(scratch == p.end())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment