"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "0b7469d3feeac96331ba38fc8c042bee03f6b42c"
Commit 98fd5e1d authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into eliminate-more-contiguous

parents f7a6d87f a1c7e7a5
...@@ -26,7 +26,6 @@ struct tf_parser ...@@ -26,7 +26,6 @@ struct tf_parser
{ {
using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>; using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>;
using node_map = std::map<std::string, tensorflow::NodeDef>; using node_map = std::map<std::string, tensorflow::NodeDef>;
// using input_node_map = std::unordered_map<std::string, std::unordered_set<std::string>>;
using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>; using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
node_map nodes; node_map nodes;
...@@ -149,9 +148,26 @@ struct tf_parser ...@@ -149,9 +148,26 @@ struct tf_parser
return axes; return axes;
} }
std::vector<int64_t> get_axes_from_mask(const size_t num_axes, const uint32_t mask)
{
uint32_t bitwise_compare = 1;
std::vector<int64_t> axes;
for(size_t i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes to begin
if(((mask >> i) & bitwise_compare) == 1)
axes.push_back(1);
else
axes.push_back(0);
}
return axes;
}
tf_parser() tf_parser()
{ {
add_generic_op("All", op::identity{});
add_generic_op("Identity", op::identity{}); add_generic_op("Identity", op::identity{});
add_generic_op("LessEqual", op::identity{});
add_generic_op("Relu", op::relu{}); add_generic_op("Relu", op::relu{});
add_generic_op("Relu6", op::clip{6.0, 0.0}); add_generic_op("Relu6", op::clip{6.0, 0.0});
add_generic_op("Rsqrt", op::rsqrt{}); add_generic_op("Rsqrt", op::rsqrt{});
...@@ -166,6 +182,7 @@ struct tf_parser ...@@ -166,6 +182,7 @@ struct tf_parser
add_mem_op("AvgPool", &tf_parser::parse_pooling); add_mem_op("AvgPool", &tf_parser::parse_pooling);
add_mem_op("BatchMatMul", &tf_parser::parse_matmul, false); add_mem_op("BatchMatMul", &tf_parser::parse_matmul, false);
add_mem_op("BatchMatMulV2", &tf_parser::parse_matmul, false);
add_mem_op("BiasAdd", &tf_parser::parse_biasadd); add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
add_mem_op("Cast", &tf_parser::parse_cast, false); add_mem_op("Cast", &tf_parser::parse_cast, false);
add_mem_op("ConcatV2", &tf_parser::parse_concat, false); add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
...@@ -177,14 +194,15 @@ struct tf_parser ...@@ -177,14 +194,15 @@ struct tf_parser
add_mem_op("GatherV2", &tf_parser::parse_gather, false); add_mem_op("GatherV2", &tf_parser::parse_gather, false);
add_mem_op("MatMul", &tf_parser::parse_matmul, false); add_mem_op("MatMul", &tf_parser::parse_matmul, false);
add_mem_op("MaxPool", &tf_parser::parse_pooling); add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean); add_mem_op("Mean", &tf_parser::parse_mean, false);
add_mem_op("OneHot", &tf_parser::parse_onehot, false);
add_mem_op("Pack", &tf_parser::parse_pack, false); add_mem_op("Pack", &tf_parser::parse_pack, false);
add_mem_op("Pad", &tf_parser::parse_pad); add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Reshape", &tf_parser::parse_reshape, false); add_mem_op("Reshape", &tf_parser::parse_reshape, false);
add_mem_op("Slice", &tf_parser::parse_slice, false); add_mem_op("Slice", &tf_parser::parse_slice, false);
add_mem_op("Softmax", &tf_parser::parse_softmax<op::softmax>); add_mem_op("Softmax", &tf_parser::parse_softmax<op::softmax>, false);
add_mem_op("Squeeze", &tf_parser::parse_squeeze, false); add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice); add_mem_op("StridedSlice", &tf_parser::parse_stridedslice, false);
add_mem_op("Transpose", &tf_parser::parse_transpose, false); add_mem_op("Transpose", &tf_parser::parse_transpose, false);
} }
...@@ -547,7 +565,7 @@ struct tf_parser ...@@ -547,7 +565,7 @@ struct tf_parser
} }
if(contains(attributes, "transpose_b")) if(contains(attributes, "transpose_b"))
{ {
transb = attributes.at("transpose_a").b(); transb = attributes.at("transpose_b").b();
} }
if(contains(attributes, "adj_x")) if(contains(attributes, "adj_x"))
...@@ -574,8 +592,7 @@ struct tf_parser ...@@ -574,8 +592,7 @@ struct tf_parser
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
bool keep_dims = attributes.at("keep_dims").b(); bool keep_dims = attributes.at("keep_dims").b();
auto lens = args[0]->get_shape().lens(); auto axes = args[1]->eval().get<int32_t>().to_vector<int64_t>();
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector<int64_t>(), lens.size());
if(keep_dims) if(keep_dims)
{ {
...@@ -588,6 +605,32 @@ struct tf_parser ...@@ -588,6 +605,32 @@ struct tf_parser
} }
} }
instruction_ref
parse_onehot(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
size_t depth = static_cast<size_t>(args[1]->eval().at<int32_t>());
int64_t axis = -1;
float on_value = args[2]->eval().at<float>();
float off_value = args[3]->eval().at<float>();
std::vector<float> depth_input(depth * depth, off_value);
for(int i = 0; i < depth; i++)
{
depth_input[depth * i + i] = on_value;
}
if(contains(attributes, "axis"))
axis = attributes.at("axis").i();
if(axis == -1)
{
shape s{shape::float_type, {depth, depth}};
auto l0 = prog.add_literal({s, depth_input});
return prog.add_instruction(op::gather{0}, {l0, args[0]});
}
MIGRAPHX_THROW("MIGraphX does not support axis != -1");
}
instruction_ref parse_pack(const std::string&, instruction_ref parse_pack(const std::string&,
const attribute_map& attributes, const attribute_map& attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
...@@ -799,21 +842,50 @@ struct tf_parser ...@@ -799,21 +842,50 @@ struct tf_parser
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
{ {
op::slice op; op::slice op;
auto starts = args[1]->eval().get<int32_t>().to_vector(); auto starts = args[1]->eval().get<int32_t>().to_vector();
auto ends = args[2]->eval().get<int32_t>().to_vector(); auto ends = args[2]->eval().get<int32_t>().to_vector();
size_t num_axes = args[0]->get_shape().lens().size(); auto l0 = args[0];
size_t num_axes = l0->get_shape().lens().size();
std::vector<size_t> axes = l0->get_shape().lens();
op.starts = std::vector<int64_t>(starts.begin(), starts.end()); op.starts = std::vector<int64_t>(starts.begin(), starts.end());
op.ends = std::vector<int64_t>(ends.begin(), ends.end()); op.ends = std::vector<int64_t>(ends.begin(), ends.end());
op.axes = std::vector<int64_t>(num_axes); op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0); std::iota(op.axes.begin(), op.axes.end(), 0);
uint32_t begin_mask = 0;
uint32_t end_mask = 0;
uint32_t shrink_axis_mask = 0; uint32_t shrink_axis_mask = 0;
uint32_t bitwise_compare = 1; uint32_t bitwise_compare = 1;
std::vector<int64_t> squeeze_axes; std::vector<int64_t> squeeze_axes;
if(contains(attributes, "begin_mask"))
begin_mask = static_cast<uint32_t>(attributes.at("begin_mask").i());
if(contains(attributes, "end_mask"))
end_mask = static_cast<uint32_t>(attributes.at("end_mask").i());
if(contains(attributes, "shrink_axis_mask")) if(contains(attributes, "shrink_axis_mask"))
shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i()); shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i());
std::vector<int64_t> begin_axes = get_axes_from_mask(num_axes, begin_mask);
std::vector<int64_t> end_axes = get_axes_from_mask(num_axes, end_mask);
for(size_t i = 0; i < num_axes; i++)
{
if(begin_axes.at(i) == 1)
{
op.starts.at(i) = 0;
}
if(end_axes.at(i) == 1)
{
op.ends.at(i) = axes.at(i);
}
}
auto l1 = prog.add_instruction(op, l0);
if(shrink_axis_mask == 0)
return l1;
for(size_t i = 0; i < num_axes; i++) for(size_t i = 0; i < num_axes; i++)
{ {
// the LSB corresponds to axis 0 when determining which axes to squeeze // the LSB corresponds to axis 0 when determining which axes to squeeze
...@@ -821,8 +893,7 @@ struct tf_parser ...@@ -821,8 +893,7 @@ struct tf_parser
squeeze_axes.push_back(i); squeeze_axes.push_back(i);
} }
auto l0 = prog.add_instruction(op, make_contiguous(args[0])); return prog.add_instruction(op::squeeze{squeeze_axes}, l1);
return to_nhwc(prog.add_instruction(op::squeeze{squeeze_axes}, l0));
} }
instruction_ref instruction_ref
...@@ -862,10 +933,16 @@ struct tf_parser ...@@ -862,10 +933,16 @@ struct tf_parser
if(instructions.count(name) == 0) if(instructions.count(name) == 0)
{ {
auto&& node = nodes.at(name); auto&& node = nodes.at(name);
// assert ops ignored
if(node.op() == "Assert" or contains(name, "Assert"))
return;
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
for(auto&& input : node.input()) for(auto&& input : node.input())
{ {
// control dependencies (signified by ^ before the name) are ignored
if(contains(input, "^"))
continue;
if(nodes.count(input) > 0) if(nodes.count(input) > 0)
{ {
auto&& iname = get_name(nodes.at(input)); auto&& iname = get_name(nodes.at(input));
......
...@@ -1093,4 +1093,394 @@ TEST_CASE(matmul_mm2) ...@@ -1093,4 +1093,394 @@ TEST_CASE(matmul_mm2)
} }
} }
TEST_CASE(quant_dot_2args_multi4)
{
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {4, 8}};
std::vector<int8_t> data1(4 * 4);
std::vector<int8_t> data2(4 * 8);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
p.add_instruction(migraphx::op::quant_dot{}, l1, l2);
std::vector<int> gold = {112, 118, 124, 130, 136, 142, 148, 154, 304, 326, 348,
370, 392, 414, 436, 458, 496, 534, 572, 610, 648, 686,
724, 762, 688, 742, 796, 850, 904, 958, 1012, 1066};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {4, 8}};
std::vector<int8_t> data1(4 * 4);
std::vector<int8_t> data2(4 * 8);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
p.add_instruction(migraphx::op::quant_dot{}, tl1, l2);
std::vector<int> gold = {448, 472, 496, 520, 544, 568, 592, 616, 496, 524, 552,
580, 608, 636, 664, 692, 544, 576, 608, 640, 672, 704,
736, 768, 592, 628, 664, 700, 736, 772, 808, 844};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 4}};
std::vector<int8_t> data1(4 * 4);
std::vector<int8_t> data2(4 * 8);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
p.add_instruction(migraphx::op::quant_dot{}, l1, tl2);
std::vector<int> gold = {14, 38, 62, 86, 110, 134, 158, 182, 38, 126, 214,
302, 390, 478, 566, 654, 62, 214, 366, 518, 670, 822,
974, 1126, 86, 302, 518, 734, 950, 1166, 1382, 1598};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 4}};
std::vector<int8_t> data1(4 * 4);
std::vector<int8_t> data2(4 * 8);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
p.add_instruction(migraphx::op::quant_dot{}, tl1, tl2);
std::vector<int> gold = {56, 152, 248, 344, 440, 536, 632, 728, 62, 174, 286,
398, 510, 622, 734, 846, 68, 196, 324, 452, 580, 708,
836, 964, 74, 218, 362, 506, 650, 794, 938, 1082};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
TEST_CASE(quant_dot_2args_general)
{
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {4, 5}};
std::vector<int8_t> data1(3 * 4);
std::vector<int8_t> data2(4 * 5);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
p.add_instruction(migraphx::op::quant_dot{}, l1, l2);
std::vector<int> gold = {
70, 76, 82, 88, 94, 190, 212, 234, 256, 278, 310, 348, 386, 424, 462};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {4, 5}};
std::vector<int8_t> data1(4 * 3);
std::vector<int8_t> data2(4 * 5);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
p.add_instruction(migraphx::op::quant_dot{}, tl1, l2);
std::vector<int> gold = {
210, 228, 246, 264, 282, 240, 262, 284, 306, 328, 270, 296, 322, 348, 374};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {5, 4}};
std::vector<int8_t> data1(3 * 4);
std::vector<int8_t> data2(4 * 5);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
p.add_instruction(
migraphx::op::quant_dot{
2,
},
l1,
tl2);
std::vector<int> gold = {
28, 76, 124, 172, 220, 76, 252, 428, 604, 780, 124, 428, 732, 1036, 1340};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {5, 4}};
std::vector<int8_t> data1(4 * 3);
std::vector<int8_t> data2(4 * 5);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2);
std::vector<int> gold = {
126, 342, 558, 774, 990, 144, 408, 672, 936, 1200, 162, 474, 786, 1098, 1410};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
TEST_CASE(quant_dot_3args_general)
{
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
std::vector<int8_t> data1(2 * 8);
std::vector<int8_t> data2(8 * 7);
std::vector<int> data3(2 * 7);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{}, l1, l2, l3);
std::vector<int> gold = {
982, 1011, 1040, 1069, 1098, 1127, 1156, 2557, 2650, 2743, 2836, 2929, 3022, 3115};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
std::vector<int8_t> data1(2 * 8);
std::vector<int8_t> data2(8 * 7);
std::vector<int> data3(2 * 7);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{1, 3}, tl1, l2, l3);
std::vector<int> gold = {
1966, 2025, 2084, 2143, 2202, 2261, 2320, 2183, 2250, 2317, 2384, 2451, 2518, 2585};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
std::vector<int8_t> data1(2 * 8);
std::vector<int8_t> data2(8 * 7);
std::vector<int> data3(2 * 7);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{2, 3}, l1, tl2, l3);
std::vector<int> gold = {
286, 737, 1188, 1639, 2090, 2541, 2992, 755, 2230, 3705, 5180, 6655, 8130, 9605};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
std::vector<int8_t> data1(2 * 8);
std::vector<int8_t> data2(8 * 7);
std::vector<int> data3(2 * 7);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3);
std::vector<int> gold = {
844, 2190, 3536, 4882, 6228, 7574, 8920, 942, 2480, 4018, 5556, 7094, 8632, 10170};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
TEST_CASE(quant_dot_3args_batch)
{
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 2, 2, 4}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {2, 2, 4, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 2, 2, 7}};
std::vector<int8_t> data1(4 * 2 * 4);
std::vector<int8_t> data2(4 * 4 * 7);
std::vector<int> data3(4 * 2 * 7);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{1, 2}, l1, l2, l3);
std::vector<int> gold = {
102, 110, 118, 126, 134, 142, 150, 284, 308, 332, 356, 380,
404, 428, 1530, 1570, 1610, 1650, 1690, 1730, 1770, 2160, 2216, 2272,
2328, 2384, 2440, 2496, 4750, 4822, 4894, 4966, 5038, 5110, 5182, 5828,
5916, 6004, 6092, 6180, 6268, 6356, 9762, 9866, 9970, 10074, 10178, 10282,
10386, 11288, 11408, 11528, 11648, 11768, 11888, 12008};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 2, 4, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {2, 2, 6, 4}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 2, 3, 6}};
std::vector<int8_t> data1(48);
std::vector<int8_t> data2(96);
std::vector<int> data3(72);
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = p.add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1);
auto l2 = p.add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l2);
auto l3 = p.add_literal(migraphx::literal{m3_shape, data3});
p.add_instruction(migraphx::op::quant_dot{2, 3}, tl1, tl2, l3);
std::vector<int> gold = {
90, 237, 384, 531, 678, 825, 120, 299, 478, 657, 836, 1015,
150, 361, 572, 783, 994, 1205, 3456, 3987, 4518, 5049, 5580, 6111,
3678, 4241, 4804, 5367, 5930, 6493, 3900, 4495, 5090, 5685, 6280, 6875,
11430, 12345, 13260, 14175, 15090, 16005, 11844, 12791, 13738, 14685, 15632, 16579,
12258, 13237, 14216, 15195, 16174, 17153, 24012, 25311, 26610, 27909, 29208, 30507,
24618, 25949, 27280, 28611, 29942, 31273, 25224, 26587, 27950, 29313, 30676, 32039};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -1483,6 +1484,107 @@ TEST_CASE(conv2d_padding_stride_test) ...@@ -1483,6 +1484,107 @@ TEST_CASE(conv2d_padding_stride_test)
EXPECT(migraphx::verify_range(results_vector, s)); EXPECT(migraphx::verify_range(results_vector, s));
} }
TEST_CASE(quant_conv2d_test)
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(migraphx::op::quant_convolution{}, al, cl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int32_t> s = {10197,
10548,
11601,
11952,
25506,
26586,
29826,
30906,
27045,
27396,
28449,
28800,
77346,
78426,
81666,
82746};
std::vector<int32_t> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(quant_conv2d_padding_test)
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{1, 1}}}, al, cl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int32_t> s = {
4521, 6753, 7014, 4635, 6858, 10197, 10548, 6939, 7830, 11601, 11952, 7839, 5007,
7383, 7590, 4953, 10515, 15987, 16734, 11277, 16821, 25506, 26586, 17874, 19737, 29826,
30906, 20718, 13593, 20505, 21198, 14187, 13161, 19281, 19542, 12699, 18522, 27045, 27396,
17739, 19494, 28449, 28800, 18639, 11919, 17319, 17526, 11289, 34707, 51843, 52590, 34893,
51813, 77346, 78426, 52002, 54729, 81666, 82746, 54846, 36057, 53769, 54462, 36075};
std::vector<int32_t> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(quant_conv2d_padding_stride_test)
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{2, 2}}}, al, cl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int32_t> s = {4521,
7014,
7830,
11952,
10515,
16734,
19737,
30906,
13161,
19542,
19494,
28800,
34707,
52590,
54729,
82746};
std::vector<int32_t> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(transpose_test) TEST_CASE(transpose_test)
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 2, 3}}; migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 2, 3}};
...@@ -1927,4 +2029,58 @@ TEST_CASE(sqdiff_test) ...@@ -1927,4 +2029,58 @@ TEST_CASE(sqdiff_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(round_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {9}};
auto l = p.add_literal(migraphx::literal{s, {1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0}});
p.add_instruction(migraphx::op::round{}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
for(auto v : results_vector)
{
std::cout << v << "\t";
}
std::cout << std::endl;
std::vector<float> gold = {1.0, 2.0, 2.0, -1.0, -2.0, -2.0, 0.0, 2.0, -2.0};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(op_capture)
{
migraphx::program p;
migraphx::shape s1{migraphx::shape::float_type, {3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
std::vector<float> d1(s1.elements());
std::vector<float> d2(s2.elements());
std::iota(d1.begin(), d1.end(), 0.0f);
std::iota(d2.begin(), d2.end(), 0.0f);
auto p1 = p.add_literal(s1, d1);
auto p2 = p.add_literal(s1, d1);
auto pb = p.add_literal(s2, d2);
auto pc = p.add_literal(s2, d2);
auto pa = p.add_instruction(migraphx::op::add{}, p1, p2);
auto ps = p.add_instruction(migraphx::op::dot{}, pa, pb, pc);
p.add_instruction(migraphx::op::dot{}, pa, ps);
migraphx::program capture_p = p;
migraphx::capture_arguments(capture_p);
p.compile(migraphx::cpu::target{});
capture_p.compile(migraphx::cpu::target{});
auto cap_res = capture_p.eval({});
auto res = p.eval({});
std::vector<float> vec;
std::vector<float> cap_vec;
cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); });
res.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(vec, cap_vec));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -1386,6 +1386,114 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0> ...@@ -1386,6 +1386,114 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
} }
}; };
struct quant_dot_3args_1 : verify_program<quant_dot_3args_1>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto l2 = p.add_parameter("b", m2_shape);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{}, l1, l2, l3);
return p;
}
};
struct quant_dot_3args_2 : verify_program<quant_dot_3args_2>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_parameter("b", m2_shape);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{1, 3}, tl1, l2, l3);
return p;
}
};
struct quant_dot_3args_3 : verify_program<quant_dot_3args_3>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto l2 = p.add_parameter("b", m2_shape);
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{2, 3}, l1, tl2, l3);
return p;
}
};
struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_parameter("b", m2_shape);
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3);
return p;
}
};
struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto tl1 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1);
auto l2 = p.add_parameter("b", m2_shape);
auto tl2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l2);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3);
return p;
}
};
struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto l2 = p.add_parameter("b", m2_shape);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{1, 3}, l1, l2, l3);
return p;
}
};
struct test_contiguous : verify_program<test_contiguous> struct test_contiguous : verify_program<test_contiguous>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -1515,6 +1623,83 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling> ...@@ -1515,6 +1623,83 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
} }
}; };
struct quant_conv : verify_program<quant_conv>
{
migraphx::program create_program()
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(migraphx::op::quant_convolution{}, pa, pc);
return p;
}
};
struct quant_conv_default_mode : verify_program<quant_conv_default_mode>
{
migraphx::program create_program()
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(
migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same},
pa,
pc);
return p;
}
};
struct quant_conv_valid_mode : verify_program<quant_conv_valid_mode>
{
migraphx::program create_program()
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(
migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::valid},
pa,
pc);
return p;
}
};
struct quant_conv_padding : verify_program<quant_conv_padding>
{
migraphx::program create_program()
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{1, 1}}}, pa, pc);
return p;
}
};
struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
{
migraphx::program create_program()
{
migraphx::program p;
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape);
p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{2, 2}}}, pa, pc);
return p;
}
};
struct test_concat : verify_program<test_concat> struct test_concat : verify_program<test_concat>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -3625,6 +3810,18 @@ struct test_reduce_mean : verify_program<test_reduce_mean> ...@@ -3625,6 +3810,18 @@ struct test_reduce_mean : verify_program<test_reduce_mean>
}; };
}; };
struct test_reduce_mean2 : verify_program<test_reduce_mean2>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1, 128, 768}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{2}}, x);
return p;
};
};
struct test_reduce_mean_int : verify_program<test_reduce_mean_int> struct test_reduce_mean_int : verify_program<test_reduce_mean_int>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -3649,4 +3846,33 @@ struct test_reduce_mean_half : verify_program<test_reduce_mean_half> ...@@ -3649,4 +3846,33 @@ struct test_reduce_mean_half : verify_program<test_reduce_mean_half>
}; };
}; };
struct test_round : verify_program<test_round>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}};
auto param = p.add_parameter("x", s);
p.add_instruction(migraphx::op::round{}, param);
return p;
}
};
struct test_convert : verify_program<test_convert>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape sa{migraphx::shape::float_type, {8, 24}};
migraphx::shape sb{migraphx::shape::float_type, {24, 6}};
auto pa = p.add_parameter("a", sa);
auto pb = p.add_parameter("b", sb);
auto ia = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, pa);
auto ib = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, pb);
p.add_instruction(migraphx::op::quant_dot{}, ia, ib);
return p;
};
};
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
namespace match = migraphx::match; namespace match = migraphx::match;
MIGRAPHX_PRED_MATCHER(throws, migraphx::instruction_ref) { MIGRAPHX_THROW("Matcher throws"); }
template <class M> template <class M>
migraphx::match::matcher_result find_match(migraphx::program& p, M&& m) migraphx::match::matcher_result find_match(migraphx::program& p, M&& m)
{ {
...@@ -331,6 +333,81 @@ TEST_CASE(match_either_args3) ...@@ -331,6 +333,81 @@ TEST_CASE(match_either_args3)
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
TEST_CASE(match_either_args_any1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m =
match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any4)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any5)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_all_of1) TEST_CASE(match_all_of1)
{ {
migraphx::program p; migraphx::program p;
...@@ -370,6 +447,36 @@ TEST_CASE(match_all_of3) ...@@ -370,6 +447,36 @@ TEST_CASE(match_all_of3)
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
TEST_CASE(match_lazy_any_of)
{
migraphx::program p;
auto one = p.add_literal(1);
p.add_instruction(pass_op{}, one);
auto m = match::any_of(match::any(), throws());
auto r = find_match(p, m);
EXPECT(bool{r.result == one});
}
TEST_CASE(match_lazy_all_of)
{
migraphx::program p;
auto one = p.add_literal(1);
p.add_instruction(pass_op{}, one);
auto m = match::all_of(match::none(), throws());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_lazy_none_of)
{
migraphx::program p;
auto one = p.add_literal(1);
p.add_instruction(pass_op{}, one);
auto m = match::none_of(match::any(), throws());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_any_of1) TEST_CASE(match_any_of1)
{ {
migraphx::program p; migraphx::program p;
...@@ -396,6 +503,97 @@ TEST_CASE(match_any_of2) ...@@ -396,6 +503,97 @@ TEST_CASE(match_any_of2)
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
TEST_CASE(match_any_of_lazy1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::any(), match::any()).bind("x"),
match::args(match::name("sum"), match::name("sum")).bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum});
EXPECT(not migraphx::contains(r.instructions, "y"));
}
TEST_CASE(match_any_of_lazy2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::name("@literal"), match::name("@literal")).bind("x"),
match::args(match::any(), match::any()).bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum});
EXPECT(not migraphx::contains(r.instructions, "y"));
}
TEST_CASE(match_any_of_lazy3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::any(), match::any()).bind("x"),
match::args(match::name("@literal"), match::name("@literal")).bind("y")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum});
EXPECT(not migraphx::contains(r.instructions, "y"));
}
TEST_CASE(match_any_of_lazy4)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::any_of(
match::args(match::name("@literal").bind("x1"), match::name("@literal").bind("y1")),
match::args(match::any().bind("x2"), match::any().bind("y2"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x1"));
EXPECT(migraphx::contains(r.instructions, "y1"));
EXPECT(bool{r.instructions["x1"] == one});
EXPECT(bool{r.instructions["y1"] == two});
EXPECT(not migraphx::contains(r.instructions, "x2"));
EXPECT(not migraphx::contains(r.instructions, "y2"));
}
TEST_CASE(match_any_of_lazy5)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::any_of(
match::args(match::any().bind("x1"), match::any().bind("y1")),
match::args(match::name("@literal").bind("x2"), match::name("@literal").bind("y2"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x1"));
EXPECT(migraphx::contains(r.instructions, "y1"));
EXPECT(bool{r.instructions["x1"] == one});
EXPECT(bool{r.instructions["y1"] == two});
EXPECT(not migraphx::contains(r.instructions, "x2"));
EXPECT(not migraphx::contains(r.instructions, "y2"));
}
TEST_CASE(match_none_of1) TEST_CASE(match_none_of1)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -4,9 +4,15 @@ ...@@ -4,9 +4,15 @@
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include "test.hpp" #include "test.hpp"
TEST_CASE(conv_autopad_fail_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("conv_autopad_fail_test.onnx"); }));
}
TEST_CASE(pytorch_conv_bias_test) TEST_CASE(pytorch_conv_bias_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -509,6 +515,32 @@ TEST_CASE(shape_gather_test) ...@@ -509,6 +515,32 @@ TEST_CASE(shape_gather_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(transpose_gather_test)
{
migraphx::program p;
auto make_contiguous = [&p](migraphx::instruction_ref ins) {
if(ins->get_shape().standard())
{
return ins;
}
return p.add_instruction(migraphx::op::contiguous{}, ins);
};
auto data = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 5, 4, 6}});
auto ind =
p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 4, 3, 5}});
auto tr_data = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3}}, data);
auto tr_ind = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3}}, ind);
int axis = 1;
p.add_instruction(
migraphx::op::gather{axis}, make_contiguous(tr_data), make_contiguous(tr_ind));
auto prog = migraphx::parse_onnx("transpose_gather.onnx");
EXPECT(p == prog);
}
TEST_CASE(flatten_test) TEST_CASE(flatten_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -1022,4 +1054,14 @@ TEST_CASE(expand_test) ...@@ -1022,4 +1054,14 @@ TEST_CASE(expand_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(round_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}});
p.add_instruction(migraphx::op::round{}, input);
auto prog = migraphx::parse_onnx("round_test.onnx");
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
 round-example:E
xy"Round
test_roundZ
x
 

b
y
 

B
\ No newline at end of file
...@@ -76,6 +76,26 @@ TEST_CASE(convolution_shape) ...@@ -76,6 +76,26 @@ TEST_CASE(convolution_shape)
throws_shape(migraphx::op::convolution{}, input2, weights); throws_shape(migraphx::op::convolution{}, input2, weights);
} }
TEST_CASE(quant_convolution_shape)
{
migraphx::shape output{migraphx::shape::int32_type, {4, 4, 1, 1}};
migraphx::shape input{migraphx::shape::int8_type, {4, 3, 3, 3}};
migraphx::shape weights{migraphx::shape::int8_type, {4, 3, 3, 3}};
expect_shape(output, migraphx::op::quant_convolution{}, input, weights);
throws_shape(migraphx::op::quant_convolution{}, input);
migraphx::shape input2{migraphx::shape::int32_type, {3, 3}};
migraphx::shape weights2{migraphx::shape::float_type, {3, 3}};
throws_shape(migraphx::op::quant_convolution{}, input2, weights2);
throws_shape(migraphx::op::quant_convolution{}, input2, weights);
migraphx::shape input3{migraphx::shape::int32_type, {4, 3, 3, 3}};
migraphx::shape weight3{migraphx::shape::float_type, {4, 3, 3, 3}};
throws_shape(migraphx::op::quant_convolution{}, input3, weights);
throws_shape(migraphx::op::quant_convolution{}, input, weight3);
throws_shape(migraphx::op::quant_convolution{}, input3, weight3);
}
TEST_CASE(transpose_shape) TEST_CASE(transpose_shape)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 2}}; migraphx::shape input{migraphx::shape::float_type, {2, 2}};
...@@ -679,6 +699,61 @@ TEST_CASE(gemm) ...@@ -679,6 +699,61 @@ TEST_CASE(gemm)
} }
} }
// quant_dot
TEST_CASE(quant_dot_2args)
{
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
migraphx::op::quant_dot{},
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {3, 8}};
migraphx::shape s_m2{migraphx::shape::int8_type, {8, 7}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {3, 7}},
migraphx::op::quant_dot{1, 0},
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 3}};
migraphx::shape s_m2{migraphx::shape::int8_type, {3, 8}};
throws_shape(migraphx::op::quant_dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {8, 8}};
throws_shape(migraphx::op::quant_dot{}, s_m1, s_m2);
}
}
TEST_CASE(quant_dot_3args)
{
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
migraphx::shape s_m3{migraphx::shape::int32_type, {2, 8}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
migraphx::op::quant_dot{},
s_m1,
s_m2,
s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
migraphx::shape s_m3{migraphx::shape::int8_type, {2, 8}};
throws_shape(migraphx::op::quant_dot{1, 2}, s_m1, s_m2, s_m3);
}
}
TEST_CASE(rnn) TEST_CASE(rnn)
{ {
{ {
......
...@@ -150,4 +150,55 @@ TEST_CASE(simplify_mul_conv1) ...@@ -150,4 +150,55 @@ TEST_CASE(simplify_mul_conv1)
EXPECT(new_conv->outputs().front()->name() != "mul"); EXPECT(new_conv->outputs().front()->name() != "mul");
} }
TEST_CASE(simplify_mul_add)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto sum = p1.add_instruction(migraphx::op::add{}, one, x);
auto mul = p1.add_instruction(migraphx::op::mul{}, sum, two);
p1.add_instruction(pass_op{}, mul);
}
p1.compile(simplify_algebra_target{});
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto mul1 = p2.add_instruction(migraphx::op::mul{}, two, x);
auto mul2 = p2.add_instruction(migraphx::op::mul{}, two, one);
auto sum = p2.add_instruction(migraphx::op::add{}, mul1, mul2);
p2.add_instruction(pass_op{}, sum);
}
EXPECT(p1 == p2);
}
TEST_CASE(simplify_inner_broadcast)
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = p1.add_parameter("y", {migraphx::shape::int32_type, {1}});
auto xb = p1.add_instruction(b, x);
auto yb = p1.add_instruction(b, y);
auto sum = p1.add_instruction(migraphx::op::add{}, xb, yb);
p1.add_instruction(pass_op{}, sum);
}
p1.compile(simplify_algebra_target{});
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = p2.add_parameter("y", {migraphx::shape::int32_type, {1}});
auto sum = p2.add_instruction(migraphx::op::add{}, x, y);
auto sumb = p2.add_instruction(b, sum);
p2.add_instruction(pass_op{}, sumb);
}
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -48,6 +48,22 @@ TEST_CASE(add_bcast_test) ...@@ -48,6 +48,22 @@ TEST_CASE(add_bcast_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(assert_less_equal_test)
{
migraphx::program p;
migraphx::shape s0{migraphx::shape::float_type, {2, 3}};
auto l0 = p.add_parameter("0", s0);
auto l1 = p.add_parameter("1", s0);
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {0, 1}};
auto l2 = p.add_literal(l);
p.add_instruction(migraphx::op::add{}, l0, l1);
auto l3 = p.add_instruction(migraphx::op::identity{}, l0, l1);
p.add_instruction(migraphx::op::identity{}, l3, l2);
auto prog = optimize_tf("assert_less_equal_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(batchmatmul_test) TEST_CASE(batchmatmul_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -100,6 +116,16 @@ TEST_CASE(biasadd_test) ...@@ -100,6 +116,16 @@ TEST_CASE(biasadd_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(cast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, l0);
auto prog = optimize_tf("cast_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(concat_test) TEST_CASE(concat_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -118,16 +144,6 @@ TEST_CASE(concat_test) ...@@ -118,16 +144,6 @@ TEST_CASE(concat_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(cast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, l0);
auto prog = optimize_tf("cast_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(const_test) TEST_CASE(const_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -271,9 +287,10 @@ TEST_CASE(mean_test_nhwc) ...@@ -271,9 +287,10 @@ TEST_CASE(mean_test_nhwc)
migraphx::program p; migraphx::program p;
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}}; migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
migraphx::op::reduce_mean op{{2, 3}}; auto l1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto l3 = p.add_instruction(op, l0); migraphx::op::reduce_mean op{{1, 2}};
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3); auto l2 = p.add_instruction(op, l1);
p.add_instruction(migraphx::op::squeeze{{1, 2}}, l2);
auto prog = optimize_tf("mean_test_nhwc.pb", true); auto prog = optimize_tf("mean_test_nhwc.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
...@@ -291,6 +308,23 @@ TEST_CASE(mul_test) ...@@ -291,6 +308,23 @@ TEST_CASE(mul_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(onehot_test)
{
migraphx::program p;
auto l0 = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {5}}, {1, 1, 1, 1, 1}});
p.add_literal(2);
p.add_literal(1.0f);
p.add_literal(0.0f);
auto l1 = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 2}}, {1, 0, 0, 1}});
int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, l1, l0);
auto prog = optimize_tf("onehot_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(pack_test) TEST_CASE(pack_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -475,20 +509,44 @@ TEST_CASE(stridedslice_test) ...@@ -475,20 +509,44 @@ TEST_CASE(stridedslice_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
auto l1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
std::size_t num_axes = 4; std::size_t num_axes = 4;
migraphx::op::slice op; migraphx::op::slice op;
op.starts = {0, 0, 0, 0}; op.starts = {0, 0, 0, 0};
op.ends = {1, 1, 1, 5}; op.ends = {1, 1, 1, 5};
op.axes = std::vector<int64_t>(num_axes); op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0); std::iota(op.axes.begin(), op.axes.end(), 0);
auto l1 = p.add_instruction(op, l0); auto l2 = p.add_instruction(op, l1);
auto shrink_axis = 1; auto shrink_axis = 1;
p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l1); p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l2);
auto prog = optimize_tf("stridedslice_test.pb", true); auto prog = optimize_tf("stridedslice_test.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(stridedslice_masks_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}});
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 1, 1, 0};
op.ends = {1, 3, 3, 10};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
// add literals for starts, ends, and strides in tf (NHWC format)
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 1, 1, 0});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 0, 0, 0});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 1});
auto l1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto l2 = p.add_instruction(op, l1);
p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l2);
auto prog = migraphx::parse_tf("stridedslice_masks_test.pb", true);
EXPECT(p == prog);
}
TEST_CASE(sub_test) TEST_CASE(sub_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -202,4 +202,55 @@ TEST_CASE(literal_add) ...@@ -202,4 +202,55 @@ TEST_CASE(literal_add)
} }
} }
TEST_CASE(op_capture)
{
auto test_func = [&](std::size_t ins_index, const std::vector<migraphx::argument>& args) {
(void)ins_index;
(void)args;
};
auto create_program_float = [] {
migraphx::program p;
migraphx::shape s1{migraphx::shape::float_type, {3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
auto p1 = p.add_parameter("x", s1);
auto p2 = p.add_parameter("y", s1);
auto pb = p.add_parameter("b", s2);
auto pc = p.add_parameter("c", s2);
auto pa = p.add_instruction(migraphx::op::add{}, p1, p2);
auto ps = p.add_instruction(migraphx::op::dot{}, pa, pb, pc);
p.add_instruction(migraphx::op::dot{}, pa, ps);
return p;
};
auto create_program_op = [&] {
migraphx::program p;
migraphx::shape s1{migraphx::shape::float_type, {3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
auto p1 = p.add_parameter("x", s1);
auto p2 = p.add_parameter("y", s1);
auto pb = p.add_parameter("b", s2);
auto pc = p.add_parameter("c", s2);
auto pa = p.add_instruction(migraphx::op::add{}, p1, p2);
auto opb = p.insert_instruction(std::next(pb), migraphx::op::capture{1, test_func}, pb);
auto opc = p.insert_instruction(std::next(pc), migraphx::op::capture{2, test_func}, pc);
auto opa = p.add_instruction(migraphx::op::capture{0, test_func}, pa);
auto ps = p.add_instruction(migraphx::op::dot{}, opa, opb, opc);
auto ops = p.add_instruction(migraphx::op::capture{3, test_func}, ps);
p.add_instruction(migraphx::op::dot{}, opa, ops);
return p;
};
{
auto p = create_program_float();
auto op_capture_p = create_program_op();
migraphx::capture_arguments(p);
EXPECT(p == op_capture_p);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment