Commit 7dc6e3ae authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents f94d77fc a275f590
......@@ -7,19 +7,6 @@ namespace match = migraphx::match;
MIGRAPHX_PRED_MATCHER(throws, migraphx::instruction_ref) { MIGRAPHX_THROW("Matcher throws"); }
template <class M>
migraphx::match::matcher_result find_match(migraphx::module& modl, M&& m)
{
migraphx::match::matcher_result result;
for(auto ins : migraphx::iterator_for(modl))
{
result = migraphx::match::match_instruction(modl, ins, m);
if(result.result != modl.end())
return result;
}
return result;
}
void match1()
{
migraphx::module mm;
......
......@@ -253,4 +253,27 @@ TEST_CASE(submodule_copy)
EXPECT(mm.get_sub_modules() == mm2.get_sub_modules());
}
TEST_CASE(parameter_name_order)
{
migraphx::shape s{migraphx::shape::int32_type, {1}};
migraphx::module mm("main");
auto x1 = mm.add_parameter("x1", s);
auto x2 = mm.add_parameter("x2", s);
auto x3 = mm.add_parameter("x3", s);
auto x4 = mm.add_parameter("x4", s);
std::vector<std::string> param_names = {"x1", "x2", "x3", "x4"};
auto sum1 = mm.add_instruction(migraphx::make_op("add"), x1, x2);
auto sum2 = mm.add_instruction(migraphx::make_op("add"), x3, x4);
auto r = mm.add_instruction(migraphx::make_op("mul"), sum1, sum2);
mm.add_return({r});
auto names = mm.get_parameter_names();
EXPECT(param_names == names);
auto m1 = mm;
auto names1 = m1.get_parameter_names();
EXPECT(param_names == names1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -2346,6 +2346,84 @@ def logsoftmax_nonstd_input_test():
return ([node0, node1], [x], [y])
@onnx_test
def loop_default_test():
body = helper.make_graph([
helper.make_node("Add", ["a", "b_in"], ["my_local"]),
helper.make_node("Sub", ["a", "b_in"], ["a_sub_b_in"]),
helper.make_node("Greater", ["my_local", "a_sub_b_in"],
["keep_going"]),
helper.make_node("Add", ["a_sub_b_in", "a_sub_b_in"],
["user_defined_vals"]),
], "body", [
helper.make_tensor_value_info('iteration_num', TensorProto.INT64, []),
helper.make_tensor_value_info('keep_going_inp', TensorProto.BOOL, []),
helper.make_tensor_value_info('b_in', TensorProto.FLOAT, [])
], [
helper.make_tensor_value_info('keep_going', TensorProto.BOOL, []),
helper.make_tensor_value_info('a_sub_b_in', TensorProto.FLOAT, []),
helper.make_tensor_value_info('my_local', TensorProto.FLOAT, []),
helper.make_tensor_value_info('user_defined_vals', TensorProto.FLOAT,
[]),
])
node = helper.make_node(
"Loop",
inputs=["", "", "b"],
outputs=["b_loop", "my_local_loop", "user_defined_vals_loop"],
body=body)
a = helper.make_tensor_value_info('a', TensorProto.FLOAT, [])
b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [])
b_loop = helper.make_tensor_value_info('b_loop', TensorProto.FLOAT, [])
uout = helper.make_tensor_value_info('user_defined_vals_loop',
TensorProto.FLOAT, [2, 1])
return ([node], [a, b], [b_loop, uout])
@onnx_test
def loop_test():
body = helper.make_graph([
helper.make_node("Add", ["a", "b_in"], ["my_local"]),
helper.make_node("Sub", ["a", "b_in"], ["a_sub_b_in"]),
helper.make_node("Greater", ["my_local", "a_sub_b_in"],
["keep_going"]),
helper.make_node("Add", ["a_sub_b_in", "a_sub_b_in"],
["user_defined_vals"]),
], "body", [
helper.make_tensor_value_info('iteration_num', TensorProto.INT64, [1]),
helper.make_tensor_value_info('keep_going_inp', TensorProto.BOOL, [1]),
helper.make_tensor_value_info('b_in', TensorProto.FLOAT, [1])
], [
helper.make_tensor_value_info('keep_going', TensorProto.BOOL, [1]),
helper.make_tensor_value_info('a_sub_b_in', TensorProto.FLOAT, [1]),
helper.make_tensor_value_info('my_local', TensorProto.FLOAT, [1]),
helper.make_tensor_value_info('user_defined_vals', TensorProto.FLOAT,
[1]),
])
node = helper.make_node(
"Loop",
inputs=["max_trip_count", "keep_going_cond", "b"],
outputs=["b_loop", "my_local_loop", "user_defined_vals_loop"],
body=body)
a = helper.make_tensor_value_info('a', TensorProto.FLOAT, [1])
b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [1])
cond = helper.make_tensor_value_info('keep_going_cond', TensorProto.BOOL,
[1])
iter = helper.make_tensor_value_info('max_trip_count', TensorProto.INT64,
[1])
b_loop = helper.make_tensor_value_info('b_loop', TensorProto.FLOAT, [1])
uout = helper.make_tensor_value_info('user_defined_vals_loop',
TensorProto.FLOAT, [2, 1])
return ([node], [iter, cond, a, b], [b_loop, uout])
@onnx_test
def lrn_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 28, 24, 24])
......@@ -4126,6 +4204,46 @@ def tanh_test():
return ([node], [x], [y])
@onnx_test
def thresholdedrelu_default_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 3])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 3])
node = onnx.helper.make_node('ThresholdedRelu',
inputs=['x'],
outputs=['y'])
return ([node], [x], [y])
@onnx_test
def thresholdedrelu_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 3])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 3])
alpha = 3.0
node = onnx.helper.make_node('ThresholdedRelu',
inputs=['x'],
outputs=['y'],
alpha=alpha)
return ([node], [x], [y])
@onnx_test
def thresholdedrelu_int_test():
x = helper.make_tensor_value_info('x', TensorProto.INT32, [2, 2, 3])
y = helper.make_tensor_value_info('y', TensorProto.INT32, [2, 2, 3])
alpha = 3.0
node = onnx.helper.make_node('ThresholdedRelu',
inputs=['x'],
outputs=['y'],
alpha=alpha)
return ([node], [x], [y])
@onnx_test
def tile_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2])
......@@ -4150,6 +4268,90 @@ def tile_test_3x2():
[helper.make_tensor('y', TensorProto.INT64, [2], [3, 2])])
@onnx_test
def topk_attrk_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 5, 3, 2])
val = helper.make_tensor_value_info('val', TensorProto.FLOAT, [2, 2, 3, 2])
ind = helper.make_tensor_value_info('indices', TensorProto.INT64,
[2, 2, 3, 2])
node = onnx.helper.make_node('TopK',
inputs=['data'],
outputs=['val', 'indices'],
k=2)
return ([node], [x], [val, ind])
@onnx_test
def topk_neg_axis_test():
k = np.array([3])
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6])
val = helper.make_tensor_value_info('val', TensorProto.FLOAT, [3, 3, 5, 6])
ind = helper.make_tensor_value_info('indices', TensorProto.INT64,
[3, 3, 5, 6])
k_tensor = helper.make_tensor(name='k',
data_type=TensorProto.INT64,
dims=k.shape,
vals=k.astype(np.int64))
node = onnx.helper.make_node('TopK',
inputs=['data', 'k'],
outputs=['val', 'indices'],
axis=-2,
sorted=0)
return ([node], [x], [val, ind], [k_tensor])
@onnx_test
def topk_test():
k = np.array([4])
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 5, 3, 2])
val = helper.make_tensor_value_info('val', TensorProto.FLOAT, [2, 4, 3, 2])
ind = helper.make_tensor_value_info('indices', TensorProto.INT64,
[2, 4, 3, 2])
k_tensor = helper.make_tensor(name='k',
data_type=TensorProto.INT64,
dims=k.shape,
vals=k.astype(np.int64))
node = onnx.helper.make_node('TopK',
inputs=['data', 'k'],
outputs=['val', 'indices'],
largest=0,
axis=1)
return ([node], [x], [val, ind], [k_tensor])
def transpose_default_perm_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 5, 2, 3])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 2, 5, 1])
node = onnx.helper.make_node(
'Transpose',
inputs=['0'],
outputs=['1'],
)
return ([node], [x], [y])
@onnx_test
def transpose_invalid_perm_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 2, 4, 3])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 2, 2])
node = onnx.helper.make_node(
'Transpose',
perm=[0, 2, 1],
inputs=['0'],
outputs=['1'],
)
return ([node], [x], [y])
@onnx_test
def transpose_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 2, 2, 3])
......
 loop_test:

max_trip_count
keep_going_cond
bb_loop my_local_loopuser_defined_vals_loop"Loop*
body2

a
b_inmy_local"Add

a
b_in
a_sub_b_in"Sub
+
my_local
a_sub_b_in
keep_going"Greater
0
a_sub_b_in
a_sub_b_inuser_defined_vals"AddbodyZ
iteration_num

Z
keep_going_inp
 
Z
b_in

b
keep_going
 
b
a_sub_b_in

b
my_local

b
user_defined_vals

 loop_testZ
max_trip_count

Z
keep_going_cond
 
Z
a

Z
b

b
b_loop

b(
user_defined_vals_loop


B
\ No newline at end of file
......@@ -74,7 +74,7 @@ TEST_CASE(add_bcast_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", l0->get_shape().lens()}}), l1);
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", l0->get_shape().lens()}}), l1);
mm->add_instruction(migraphx::make_op("add"), l0, l2);
auto prog = optimize_onnx("add_bcast_test.onnx");
......@@ -102,8 +102,8 @@ TEST_CASE(add_scalar_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::uint8_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::uint8_type});
auto m1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4, 5}}}), l1);
auto m1 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), l1);
auto r = mm->add_instruction(migraphx::make_op("add"), l0, m1);
mm->add_return({r});
auto prog = migraphx::parse_onnx("add_scalar_test.onnx");
......@@ -373,9 +373,9 @@ TEST_CASE(clip_test)
auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f);
min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}), min_val);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), min_val);
max_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}), max_val);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), max_val);
mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
auto prog = optimize_onnx("clip_test.onnx");
......@@ -390,7 +390,7 @@ TEST_CASE(clip_test_op11_max_only)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
mm->add_instruction(migraphx::make_op("undefined"));
max_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}), max_val);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), max_val);
auto r = mm->add_instruction(migraphx::make_op("min"), l0, max_val);
mm->add_return({r});
......@@ -407,9 +407,9 @@ TEST_CASE(clip_test_op11)
auto max_val = mm->add_literal(6.0f);
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}), min_val);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), min_val);
max_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}), max_val);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), max_val);
mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
auto prog = optimize_onnx("clip_test_op11.onnx");
......@@ -423,7 +423,7 @@ TEST_CASE(clip_test_op11_min_only)
auto min_val = mm->add_literal(0.0f);
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}), min_val);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), min_val);
mm->add_instruction(migraphx::make_op("max"), l0, min_val);
auto prog = optimize_onnx("clip_test_op11_min_only.onnx");
......@@ -638,7 +638,7 @@ TEST_CASE(conv_bias_test)
uint64_t axis = 1;
auto l3 = mm->add_instruction(migraphx::make_op("convolution"), l0, l1);
auto l4 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l3->get_shape().lens()}}), l2);
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2);
mm->add_instruction(migraphx::make_op("add"), l3, l4);
auto prog = optimize_onnx("conv_bias_test.onnx");
......@@ -661,7 +661,7 @@ TEST_CASE(conv_bn_relu_maxpool_test)
auto l3 =
mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l0, l1);
auto l4 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l3->get_shape().lens()}}), l2);
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2);
auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4);
auto l6 = mm->add_instruction(
migraphx::make_op("batch_norm_inference", {{"epsilon", 1.0e-5f}}), l5, p3, p4, p5, p6);
......@@ -687,7 +687,7 @@ TEST_CASE(conv_relu_maxpool_test)
auto l3 =
mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l0, l1);
auto l4 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l3->get_shape().lens()}}), l2);
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2);
auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4);
auto l6 = mm->add_instruction(migraphx::make_op("relu"), l5);
mm->add_instruction(
......@@ -711,7 +711,7 @@ TEST_CASE(conv_relu_maxpool_x2_test)
auto l3 =
mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l0, l1);
auto l4 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l3->get_shape().lens()}}), l2);
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2);
auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4);
auto l6 = mm->add_instruction(migraphx::make_op("relu"), l5);
auto l7 = mm->add_instruction(
......@@ -725,7 +725,8 @@ TEST_CASE(conv_relu_maxpool_x2_test)
auto l10 =
mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l7, l8);
auto l11 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l10->get_shape().lens()}}), l9);
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l10->get_shape().lens()}}),
l9);
auto l12 = mm->add_instruction(migraphx::make_op("add"), l10, l11);
auto l13 = mm->add_instruction(migraphx::make_op("relu"), l12);
mm->add_instruction(
......@@ -749,7 +750,7 @@ TEST_CASE(convinteger_bias_test)
uint64_t axis = 1;
auto l3 = mm->add_instruction(migraphx::make_op("quant_convolution"), l0, l1);
auto l4 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l3->get_shape().lens()}}), l2);
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2);
mm->add_instruction(migraphx::make_op("add"), l3, l4);
auto prog = optimize_onnx("convinteger_bias_test.onnx");
......@@ -801,7 +802,7 @@ TEST_CASE(deconv_bias_test)
uint64_t axis = 1;
auto l3 = mm->add_instruction(migraphx::make_op("deconvolution"), l0, l1);
auto l4 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l3->get_shape().lens()}}), l2);
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2);
mm->add_instruction(migraphx::make_op("add"), l3, l4);
auto prog = optimize_onnx("deconv_bias_test.onnx");
......@@ -923,7 +924,7 @@ TEST_CASE(dequantizelinear_test)
auto l0 = mm->add_parameter("0", {migraphx::shape::int8_type, {5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1);
auto dequant = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
......@@ -942,9 +943,9 @@ TEST_CASE(dequantizelinear_zero_point_test)
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}});
auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1);
auto l2_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l2);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l2);
l2_mbcast = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
......@@ -971,9 +972,9 @@ migraphx::program make_dequantizelinear_axis_prog()
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {5}});
auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {5}});
auto l1_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l1);
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", input_lens}}), l1);
auto l2_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l2);
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", input_lens}}), l2);
l2_bcast = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
......@@ -1129,8 +1130,7 @@ TEST_CASE(expand_test)
auto param = mm->add_parameter("x", s);
migraphx::shape ss(migraphx::shape::int32_type, {4});
mm->add_literal(migraphx::literal(ss, {2, 3, 4, 5}));
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4, 5}}}),
param);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), param);
auto prog = optimize_onnx("expand_test.onnx");
EXPECT(p == prog);
......@@ -1150,7 +1150,7 @@ migraphx::program create_external_data_prog()
auto conv = mm->add_instruction(
migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), param, weights);
auto bias_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 10, 214, 214}}}), bias);
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 10, 214, 214}}}), bias);
mm->add_instruction(migraphx::make_op("add"), conv, bias_bcast);
return p;
}
......@@ -1188,8 +1188,9 @@ TEST_CASE(flatten_nonstd_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 5, 4}});
auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l0);
auto l2 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l0);
auto l2 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
mm->add_instruction(migraphx::make_op("flatten", {{"axis", 2}}), l2);
auto l3 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
mm->add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), l3);
......@@ -1240,7 +1241,7 @@ TEST_CASE(gather_elements_axis0_test)
auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data);
auto lbst_stride = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", ind_s.lens()}}), l_stride);
migraphx::make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);
auto axis_delta = mm->add_instruction(migraphx::make_op("sub"), indices, l_ind_axis_indices);
auto mul_delta = mm->add_instruction(migraphx::make_op("mul"), axis_delta, lbst_stride);
auto ind = mm->add_instruction(migraphx::make_op("add"), l_data_indices, mul_delta);
......@@ -1269,7 +1270,7 @@ TEST_CASE(gather_elements_axis1_test)
auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data);
auto lbst_stride = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", ind_s.lens()}}), l_stride);
migraphx::make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);
auto axis_delta = mm->add_instruction(migraphx::make_op("sub"), indices, l_ind_axis_indices);
auto mul_delta = mm->add_instruction(migraphx::make_op("mul"), axis_delta, lbst_stride);
auto ind = mm->add_instruction(migraphx::make_op("add"), l_data_indices, mul_delta);
......@@ -1292,17 +1293,19 @@ TEST_CASE(gemm_test)
auto beta = 2.0f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, t1);
auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {7, 11}}}), l2);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {7, 11}}}), l2);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l2_b->get_shape().lens()}}), b_l);
migraphx::make_op("multibroadcast", {{"out_lens", l2_b->get_shape().lens()}}), b_l);
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), t_a, t1, l2_bb);
mm->add_instruction(migraphx::make_op("add"), dot, l2_bb);
auto prog = optimize_onnx("gemm_test.onnx");
EXPECT(p == prog);
}
......@@ -1318,16 +1321,17 @@ TEST_CASE(gemm_ex_test)
auto beta = 0.8f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1);
auto b_l = mm->add_literal(beta);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l2->get_shape().lens()}}), b_l);
migraphx::make_op("multibroadcast", {{"out_lens", l2->get_shape().lens()}}), b_l);
auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), t_a, l1, l2_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_b);
auto prog = optimize_onnx("gemm_ex_test.onnx");
EXPECT(p == prog);
}
......@@ -1343,19 +1347,19 @@ TEST_CASE(gemm_ex_brcst_test)
auto beta = 0.8f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1);
auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", out_lens}}), l2);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), l2);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l2_b->get_shape().lens()}}), b_l);
migraphx::make_op("multibroadcast", {{"out_lens", l2_b->get_shape().lens()}}), b_l);
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), t_a, l1, l2_bb);
mm->add_instruction(migraphx::make_op("add"), dot, l2_bb);
auto prog = optimize_onnx("gemm_ex_brcst_test.onnx");
EXPECT(p == prog);
}
......@@ -1372,21 +1376,21 @@ TEST_CASE(gemm_half_test)
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
std::vector<std::size_t> lens = {1, 1, 6, 7};
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), l2);
auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1);
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2);
l2 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2);
auto b_l = mm->add_literal(beta);
auto b_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), b_l);
auto b_l = mm->add_literal(beta);
auto b_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), b_l);
auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b);
l2_b = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l2_b);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), t_a, l1, l2_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_b);
auto prog = optimize_onnx("gemm_half_test.onnx");
EXPECT(p == prog);
}
......@@ -1666,20 +1670,20 @@ TEST_CASE(if_tuple_test)
auto y = mm->add_parameter("y", sy);
auto* then_mod = p.create_module("If_6_if");
auto m1 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1);
auto m1 =
then_mod->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1);
auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2);
auto m2 =
then_mod->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2);
auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2);
then_mod->add_return({add0, mul0});
auto* else_mod = p.create_module("If_6_else");
auto me1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3);
auto me1 =
else_mod->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3);
auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1);
auto me2 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3);
auto me2 =
else_mod->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3);
auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2);
else_mod->add_return({mul1, add1});
......@@ -1705,7 +1709,7 @@ TEST_CASE(imagescaler_test)
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), scale_val);
auto img_scaled = mm->add_instruction(migraphx::make_op("mul"), l0, scaled_tensor);
auto bias_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", s.lens()}}), bias_vals);
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), bias_vals);
mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
auto prog = optimize_onnx("imagescaler_test.onnx");
......@@ -1727,7 +1731,7 @@ TEST_CASE(imagescaler_half_test)
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), scale_val);
auto img_scaled = mm->add_instruction(migraphx::make_op("mul"), l0, scaled_tensor);
auto bias_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", s.lens()}}), bias_vals);
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), bias_vals);
mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
auto prog = optimize_onnx("imagescaler_half_test.onnx");
......@@ -1741,8 +1745,8 @@ TEST_CASE(implicit_add_bcast_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l3 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4, 5}}}), l1);
auto l3 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), l1);
mm->add_instruction(migraphx::make_op("add"), l0, l3);
auto prog = optimize_onnx("implicit_add_bcast_test.onnx");
......@@ -1756,8 +1760,8 @@ TEST_CASE(implicit_add_bcast_user_input_shape_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5, 1}});
auto l3 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4, 5, 6}}}), l1);
auto l3 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4, 5, 6}}}), l1);
auto r = mm->add_instruction(migraphx::make_op("add"), l0, l3);
mm->add_return({r});
......@@ -1775,8 +1779,8 @@ TEST_CASE(implicit_pow_bcast_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l3 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4, 5}}}), l1);
auto l3 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), l1);
mm->add_instruction(migraphx::make_op("pow"), l0, l3);
auto prog = optimize_onnx("implicit_pow_bcast_test.onnx");
......@@ -1790,8 +1794,8 @@ TEST_CASE(implicit_sub_bcast_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::uint64_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::uint64_type, {4, 5}});
auto l3 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4, 5}}}), l1);
auto l3 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), l1);
mm->add_instruction(migraphx::make_op("sub"), l0, l3);
auto prog = optimize_onnx("implicit_sub_bcast_test.onnx");
......@@ -1806,7 +1810,7 @@ TEST_CASE(initializer_not_an_input)
std::vector<float> w = {1, 2, 3, 4, 5, 6, 7, 8};
auto l1 = mm->add_literal(migraphx::literal({migraphx::shape::float_type, {2, 4}}, w));
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 2}});
mm->add_instruction(migraphx::make_op("dot"), l0, l1);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l0, l1);
auto prog = optimize_onnx("initializer_not_an_input.onnx");
......@@ -1827,22 +1831,22 @@ TEST_CASE(instance_norm_test)
auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), x);
auto mean_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", dims}}), mean);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto l0 = mm->add_instruction(migraphx::make_op("sqdiff"), x, mean_bcast);
auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l0);
auto l1 = mm->add_instruction(migraphx::make_op("sub"), x, mean_bcast);
auto epsilon_literal = mm->add_literal(1e-5f);
auto epsilon_bcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", dims}}), epsilon_literal);
migraphx::make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
auto variance_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", dims}}), variance);
auto l2 = mm->add_instruction(migraphx::make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = mm->add_instruction(migraphx::make_op("rsqrt"), l2);
auto l4 = mm->add_instruction(migraphx::make_op("mul"), l1, l3);
auto scale_bcast =
mm->add_instruction(migraphx::make_op("broadcast", {{"axis", 1}, {"dims", dims}}), scale);
auto bias_bcast =
mm->add_instruction(migraphx::make_op("broadcast", {{"axis", 1}, {"dims", dims}}), bias);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), variance);
auto l2 = mm->add_instruction(migraphx::make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = mm->add_instruction(migraphx::make_op("rsqrt"), l2);
auto l4 = mm->add_instruction(migraphx::make_op("mul"), l1, l3);
auto scale_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
auto bias_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto l5 = mm->add_instruction(migraphx::make_op("mul"), l4, scale_bcast);
mm->add_instruction(migraphx::make_op("add"), l5, bias_bcast);
......@@ -1940,7 +1944,7 @@ TEST_CASE(logical_and_bcast_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::bool_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::bool_type, {4, 5}});
auto l2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l0->get_shape().lens()}}), l1);
migraphx::make_op("multibroadcast", {{"out_lens", l0->get_shape().lens()}}), l1);
auto ret = mm->add_instruction(migraphx::make_op("logical_and"), l0, l2);
mm->add_return({ret});
......@@ -1970,7 +1974,7 @@ TEST_CASE(logical_xor_bcast_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::bool_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::bool_type, {4, 1}});
auto l2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l0->get_shape().lens()}}), l1);
migraphx::make_op("multibroadcast", {{"out_lens", l0->get_shape().lens()}}), l1);
auto ret = mm->add_instruction(migraphx::make_op("logical_xor"), l0, l2);
mm->add_return({ret});
......@@ -2006,6 +2010,82 @@ TEST_CASE(logsoftmax_nonstd_input_test)
EXPECT(p == prog);
}
TEST_CASE(loop_default_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape su{migraphx::shape::float_type};
auto a = mm->add_parameter("a", su);
auto b = mm->add_parameter("b", su);
migraphx::shape si{migraphx::shape::int64_type};
auto max_iter = mm->add_literal(migraphx::literal(si, {10}));
migraphx::shape sc{migraphx::shape::bool_type};
auto icond = mm->add_literal(migraphx::literal(sc, {1}));
mm->add_instruction(migraphx::make_op("undefined"));
auto* body = p.create_module("Loop_3_loop");
body->add_parameter("iteration_num", {migraphx::shape::int64_type});
body->add_parameter("keep_going_inp", {migraphx::shape::bool_type});
auto var = body->add_parameter("b_in", su);
auto ad = body->add_instruction(migraphx::make_op("add"), a, var);
auto sb = body->add_instruction(migraphx::make_op("sub"), a, var);
auto gt = body->add_instruction(migraphx::make_op("greater"), ad, sb);
auto cv = body->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), gt);
auto ad1 = body->add_instruction(migraphx::make_op("add"), sb, sb);
body->add_return({cv, sb, ad, ad1});
auto lp = mm->add_instruction(
migraphx::make_op("loop", {{"max_iterations", 10}}), {max_iter, icond, b}, {body});
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), lp);
mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), lp);
auto r2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), lp);
mm->add_return({r0, r2});
auto prog = migraphx::parse_onnx("loop_default_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(loop_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape si{migraphx::shape::int64_type, {1}};
auto max_iter = mm->add_parameter("max_trip_count", si);
migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto icond = mm->add_parameter("keep_going_cond", sc);
migraphx::shape su{migraphx::shape::float_type, {1}};
auto a = mm->add_parameter("a", su);
auto b = mm->add_parameter("b", su);
auto* body = p.create_module("Loop_4_loop");
body->add_parameter("iteration_num", si);
body->add_parameter("keep_going_inp", sc);
auto var = body->add_parameter("b_in", su);
auto ad = body->add_instruction(migraphx::make_op("add"), a, var);
auto sb = body->add_instruction(migraphx::make_op("sub"), a, var);
auto gt = body->add_instruction(migraphx::make_op("greater"), ad, sb);
auto cv = body->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), gt);
auto ad1 = body->add_instruction(migraphx::make_op("add"), sb, sb);
body->add_return({cv, sb, ad, ad1});
auto lp = mm->add_instruction(
migraphx::make_op("loop", {{"max_iterations", 10}}), {max_iter, icond, b}, {body});
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), lp);
mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), lp);
auto r2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), lp);
mm->add_return({r0, r2});
auto prog = migraphx::parse_onnx("loop_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(lrn_test)
{
migraphx::program p;
......@@ -2029,9 +2109,9 @@ TEST_CASE(matmul_bmbm_test)
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 2, 1, 7, 8}});
auto bl0 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {5, 2, 3, 6, 7}}}), l0);
migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3, 6, 7}}}), l0);
auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {5, 2, 3, 7, 8}}}), l1);
migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3, 7, 8}}}), l1);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), bl0, bl1);
auto prog = optimize_onnx("matmul_bmbm_test.onnx");
......@@ -2047,7 +2127,7 @@ TEST_CASE(matmul_bmv_test)
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1);
auto bsl1 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3, 7, 1}}}), sl1);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 7, 1}}}), sl1);
auto res =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l0, bsl1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), res);
......@@ -2081,7 +2161,7 @@ TEST_CASE(matmul_vbm_test)
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 7, 8}});
auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
auto bsl0 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5, 1, 7}}}), sl0);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5, 1, 7}}}), sl0);
auto res =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), bsl0, l1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), res);
......@@ -2294,17 +2374,17 @@ TEST_CASE(onehot_test)
std::vector<float> data_dep{1, 0, 0, 0, 1, 0, 0, 0, 1};
auto l_dep = mm->add_literal(migraphx::literal(s_dep, data_dep));
auto gather_out = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), l_dep, l_ind);
auto tr_out =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {2, 0, 1}}}), gather_out);
auto tr_out = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}),
gather_out);
auto off_val = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), l_val);
auto on_val = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), l_val);
auto diff = mm->add_instruction(migraphx::make_op("sub"), on_val, off_val);
auto mb_off_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 5, 2}}}), off_val);
auto mb_diff = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 5, 2}}}), diff);
migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 2}}}), off_val);
auto mb_diff =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 2}}}), diff);
auto mul = mm->add_instruction(migraphx::make_op("mul"), tr_out, mb_diff);
auto r = mm->add_instruction(migraphx::make_op("add"), mul, mb_off_val);
mm->add_return({r});
......@@ -2453,7 +2533,7 @@ TEST_CASE(prelu_brcst_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5}});
auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l0->get_shape().lens()}}), l1);
migraphx::make_op("multibroadcast", {{"out_lens", l0->get_shape().lens()}}), l1);
auto ret = mm->add_instruction(migraphx::make_op("prelu"), l0, bl1);
mm->add_return({ret});
......@@ -2469,7 +2549,7 @@ TEST_CASE(quantizelinear_test)
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto s = round->get_shape();
......@@ -2494,7 +2574,7 @@ TEST_CASE(quantizelinear_int32_test)
auto l0 = mm->add_parameter("0", {migraphx::shape::int32_type, {5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1);
l0 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
......@@ -2524,11 +2604,11 @@ TEST_CASE(quantizelinear_zero_point_test)
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}});
auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto l2_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l2);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l2);
l2_mbcast = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
......@@ -2560,12 +2640,12 @@ migraphx::program make_quantizelinear_axis_prog()
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {5}});
auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {5}});
auto l1_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l1);
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", input_lens}}), l1);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_bcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto l2_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l2);
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", input_lens}}), l2);
l2_bcast = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
......@@ -2848,8 +2928,9 @@ TEST_CASE(reshape_non_standard_test)
migraphx::op::reshape op;
std::vector<int64_t> reshape_dims{4, 3, 2};
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}};
auto x = mm->add_parameter("x", s);
auto tran_x = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 1}}}), x);
auto x = mm->add_parameter("x", s);
auto tran_x =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), x);
auto cont_x = mm->add_instruction(migraphx::make_op("contiguous"), tran_x);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, 3, 2}}}), cont_x);
auto prog = optimize_onnx("reshape_non_standard_test.onnx");
......@@ -3021,7 +3102,8 @@ TEST_CASE(resize_nonstd_input_test)
std::vector<int> ind = {0, 4};
auto li = mm->add_literal(migraphx::literal(si, ind));
auto tx = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), inx);
auto tx =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), inx);
mm->add_instruction(migraphx::make_op("undefined"));
auto tx_cont = mm->add_instruction(migraphx::make_op("contiguous"), tx);
......@@ -3315,12 +3397,10 @@ TEST_CASE(selu_test)
auto x = mm->add_parameter("x", s);
migraphx::shape ls{migraphx::shape::double_type, {1}};
auto la = mm->add_literal({ls, {0.3}});
auto lg = mm->add_literal({ls, {0.25}});
auto mbla =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), la);
auto mblg =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), lg);
auto la = mm->add_literal({ls, {0.3}});
auto lg = mm->add_literal({ls, {0.25}});
auto mbla = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), la);
auto mblg = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), lg);
auto sign_x = mm->add_instruction(migraphx::make_op("sign"), x);
auto exp_x = mm->add_instruction(migraphx::make_op("exp"), x);
......@@ -3647,7 +3727,7 @@ TEST_CASE(sub_bcast_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", l0->get_shape().lens()}}), l1);
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", l0->get_shape().lens()}}), l1);
mm->add_instruction(migraphx::make_op("sub"), l0, l2);
auto prog = optimize_onnx("sub_bcast_test.onnx");
......@@ -3661,8 +3741,8 @@ TEST_CASE(sub_scalar_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}});
auto m1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4, 5}}}), l1);
auto m1 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), l1);
mm->add_instruction(migraphx::make_op("sub"), l0, m1);
auto prog = optimize_onnx("sub_scalar_test.onnx");
......@@ -3778,6 +3858,63 @@ TEST_CASE(tanh_test)
EXPECT(p == prog);
}
TEST_CASE(thresholdedrelu_default_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 3}});
auto lz = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {0}});
auto la = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {1.0f}});
auto mbz = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), lz);
auto mba = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), la);
auto condition = mm->add_instruction(migraphx::make_op("greater"), x, mba);
mm->add_instruction(migraphx::make_op("where"), condition, x, mbz);
auto prog = optimize_onnx("thresholdedrelu_default_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(thresholdedrelu_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 3}});
auto lz = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {0}});
auto la = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {3.0f}});
auto mbz = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), lz);
auto mba = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), la);
auto condition = mm->add_instruction(migraphx::make_op("greater"), x, mba);
mm->add_instruction(migraphx::make_op("where"), condition, x, mbz);
auto prog = optimize_onnx("thresholdedrelu_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(thresholdedrelu_int_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2, 3}});
auto lz = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {0}});
auto la = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {3}});
auto mbz = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), lz);
auto mba = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), la);
auto condition = mm->add_instruction(migraphx::make_op("greater"), x, mba);
mm->add_instruction(migraphx::make_op("where"), condition, x, mbz);
auto prog = optimize_onnx("thresholdedrelu_int_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(tile_test)
{
migraphx::program p;
......@@ -3806,19 +3943,92 @@ TEST_CASE(tile_test_3x2)
EXPECT(p == prog);
}
TEST_CASE(transpose_default_perm_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 5, 2, 3}});
std::vector<int64_t> perm{3, 2, 1, 0};
auto r = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input);
mm->add_return({r});
auto prog = migraphx::parse_onnx("transpose_default_perm_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(transpose_invalid_perm_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("transpose_invalid_perm_test.onnx"); }));
}
TEST_CASE(transpose_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
std::vector<int64_t> perm{0, 3, 1, 2};
mm->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), input);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input);
auto prog = optimize_onnx("transpose_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(topk_attrk_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 5, 3, 2}};
auto data = mm->add_parameter("data", s);
auto out = mm->add_instruction(migraphx::make_op("topk", {{"k", 2}, {"axis", -1}}), data);
auto val = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
mm->add_return({val, ind});
auto prog = migraphx::parse_onnx("topk_attrk_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(topk_neg_axis_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sk{migraphx::shape::int64_type, {1}};
mm->add_literal(migraphx::literal(sk, {3}));
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
auto data = mm->add_parameter("data", s);
auto out = mm->add_instruction(
migraphx::make_op("topk", {{"k", 3}, {"axis", -2}, {"largest", 1}}), data);
auto val = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
mm->add_return({val, ind});
auto prog = migraphx::parse_onnx("topk_neg_axis_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(topk_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sk{migraphx::shape::int64_type, {1}};
mm->add_literal(migraphx::literal(sk, {4}));
migraphx::shape s{migraphx::shape::float_type, {2, 5, 3, 2}};
auto data = mm->add_parameter("data", s);
auto out = mm->add_instruction(
migraphx::make_op("topk", {{"k", 4}, {"axis", 1}, {"largest", 0}}), data);
auto val = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
mm->add_return({val, ind});
auto prog = migraphx::parse_onnx("topk_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(transpose_gather_test)
{
migraphx::program p;
......@@ -3837,9 +4047,9 @@ TEST_CASE(transpose_gather_test)
auto ind =
mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 4, 3, 5}});
auto tr_data =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 1, 3}}}), data);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), data);
auto tr_ind =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 1, 3}}}), ind);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), ind);
int axis = 1;
mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}),
make_contiguous(tr_data),
......@@ -3964,32 +4174,14 @@ TEST_CASE(where_test)
auto lx = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto ly = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 1, 2, 2}});
auto int_c = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
lc);
auto lccm = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 2, 2}}}), int_c);
auto lxm = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 2, 2}}}), lx);
auto lym = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 2, 2}}}), ly);
auto concat_data = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), lym, lxm);
auto rsp_data =
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {32}}}), concat_data);
std::vector<int> offset(16, 16);
std::vector<int> ind(16);
std::iota(ind.begin(), ind.end(), 0);
migraphx::shape ind_s{migraphx::shape::int32_type, {2, 2, 2, 2}};
auto lind = mm->add_literal(migraphx::literal(ind_s, ind));
auto loffset = mm->add_literal(migraphx::literal(ind_s, offset));
auto ins_co = mm->add_instruction(migraphx::make_op("mul"), loffset, lccm);
auto ins_ind = mm->add_instruction(migraphx::make_op("add"), ins_co, lind);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp_data, ins_ind);
auto lccm =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2}}}), lc);
auto lxm =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2}}}), lx);
auto lym =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2}}}), ly);
auto r = mm->add_instruction(migraphx::make_op("where"), lccm, lxm, lym);
mm->add_return({r});
auto prog = migraphx::parse_onnx("where_test.onnx");
......
thresholdedrelu_default_test:i

xy"ThresholdedReluthresholdedrelu_default_testZ
x



b
y



B
\ No newline at end of file
topk_attrk_test:
$
datavalindices"TopK*
ktopk_attrk_testZ
data




b
val




b!
indices




B
\ No newline at end of file
transpose_default_perm_test:j

01" Transposetranspose_default_perm_testZ
0




b
1




B
\ No newline at end of file
......@@ -74,7 +74,7 @@ TEST_CASE(broadcast)
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {1}, {0}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", lens}}),
migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", lens}}),
input);
}
......@@ -94,14 +94,14 @@ TEST_CASE(broadcast)
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}},
migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}),
migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}), input);
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", lens}}), input);
}
}
......@@ -953,70 +953,70 @@ TEST_CASE(multibroadcast)
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {2, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 0, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {2, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 1, 0, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {5, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 1, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 0, 0, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 0, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 3, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 1, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 1, 1, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
}
{
std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
}
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {3, 4}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
}
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
}
}
......@@ -1558,10 +1558,9 @@ TEST_CASE(transpose_shape)
{
migraphx::shape input{migraphx::shape::float_type, {2, 2}};
migraphx::shape output{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(input, migraphx::make_op("transpose", {{"dims", {0, 1}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"dims", {1, 0}}}), input);
expect_shape(output, migraphx::make_op("transpose"), input);
throws_shape(migraphx::make_op("transpose", {{"dims", {1, 2}}}), input);
expect_shape(input, migraphx::make_op("transpose", {{"permutation", {0, 1}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"permutation", {1, 0}}}), input);
throws_shape(migraphx::make_op("transpose", {{"permutation", {1, 2}}}), input);
}
TEST_CASE(step_test)
......@@ -1583,4 +1582,28 @@ TEST_CASE(step_test)
}
}
TEST_CASE(unary_scalar_input)
{
migraphx::shape ss{migraphx::shape::half_type};
expect_shape(ss, migraphx::make_op("sin"), ss);
migraphx::shape s{migraphx::shape::float_type, {1}};
expect_shape(s, migraphx::make_op("sin"), s);
}
TEST_CASE(unary_broadcast_input)
{
migraphx::shape ss{migraphx::shape::half_type, {2, 3}, {1, 0}};
migraphx::shape s{migraphx::shape::half_type, {2, 3}};
expect_shape(s, migraphx::make_op("sin"), ss);
}
TEST_CASE(where_broadcast_input)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 2}, {3, 0}};
migraphx::shape s2{migraphx::shape::float_type, {2, 2}};
migraphx::shape s3{migraphx::shape::bool_type, {2, 2}};
expect_shape(s2, migraphx::make_op("where"), s3, s1, s2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -4,7 +4,6 @@ if sys.version_info < (3, 0):
import argparse
import os
import platform
import unittest
import onnx
import onnx.backend.test
......@@ -196,6 +195,8 @@ def create_backend_test(testname=None, target_device=None):
backend_test.include(r'.*test_Tanh*')
backend_test.include(r'.*test_tanh.*')
backend_test.include(r'.*test_thresholdedrelu.*')
backend_test.include(r'.*test_topk.*')
backend_test.include(r'.*test_Topk.*')
backend_test.include(r'.*test_transpose.*')
backend_test.include(r'.*test_unsqueeze.*')
backend_test.include(r'.*test_where*')
......@@ -288,9 +289,6 @@ def create_backend_test(testname=None, target_device=None):
backend_test.exclude(r'test_softplus_example_cpu')
backend_test.exclude(r'test_softsign_cpu')
backend_test.exclude(r'test_softsign_example_cpu')
backend_test.exclude(r'test_thresholdedrelu_cpu')
backend_test.exclude(r'test_thresholdedrelu_default_cpu')
backend_test.exclude(r'test_thresholdedrelu_example_cpu')
backend_test.exclude(r'test_Embedding_cpu')
backend_test.exclude(r'test_Softplus_cpu')
......
import migraphx, array, sys
import migraphx
def test_add_op():
......
......@@ -7,46 +7,29 @@
#include <migraphx/ref/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include "test.hpp"
#include <migraphx/half.hpp>
migraphx::instruction_ref
create_clip_op(migraphx::program& p, float max, float min, migraphx::instruction_ref input)
{
auto* mm = p.get_main_module();
auto input_lens = input->get_shape().lens();
auto max_val = mm->add_literal(max);
auto min_val = mm->add_literal(min);
max_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val);
min_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val);
return mm->add_instruction(migraphx::make_op("clip"), input, min_val, max_val);
}
migraphx::instruction_ref create_clip_op(migraphx::instruction_ref insert_loc,
migraphx::program& p,
float max,
float min,
migraphx::instruction_ref input)
static void optimize_prog_int8(migraphx::program& prog)
{
auto* mm = p.get_main_module();
auto input_lens = input->get_shape().lens();
auto max_val = mm->add_literal(max);
auto min_val = mm->add_literal(min);
max_val = mm->insert_instruction(
insert_loc, migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val);
min_val = mm->insert_instruction(
insert_loc, migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val);
return mm->insert_instruction(insert_loc, migraphx::make_op("clip"), input, min_val, max_val);
migraphx::run_passes(prog,
{migraphx::simplify_qdq{},
migraphx::eliminate_common_subexpression{},
migraphx::dead_code_elimination{}});
}
TEST_CASE(param_add)
......@@ -71,9 +54,9 @@ TEST_CASE(param_add)
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = mm->add_parameter("x", s);
auto hp1 = mm->insert_instruction(std::next(p1), migraphx::make_op("convert"), p1);
auto p2 = mm->add_parameter("y", s);
auto hp2 = mm->insert_instruction(std::next(p2), migraphx::make_op("convert"), p2);
auto hp1 = mm->add_instruction(migraphx::make_op("convert"), p1);
auto hp2 = mm->add_instruction(migraphx::make_op("convert"), p2);
auto hs = mm->add_instruction(migraphx::make_op("add"), hp1, hp2);
auto res = mm->add_instruction(
migraphx::make_op("convert",
......@@ -130,7 +113,8 @@ TEST_CASE(param_add_sub)
auto p2 = mm->add_parameter("y", s);
auto sum = mm->add_instruction(migraphx::make_op("add"), p1, p2);
auto diff = mm->add_instruction(migraphx::make_op("sub"), sum, p2);
mm->add_instruction(migraphx::make_op("add"), diff, p1);
auto r = mm->add_instruction(migraphx::make_op("add"), diff, p1);
mm->add_return({r});
return p;
};
......@@ -140,32 +124,21 @@ TEST_CASE(param_add_sub)
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = mm->add_parameter("x", s);
auto hp1 = mm->insert_instruction(
std::next(p1),
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
p1);
auto p2 = mm->add_parameter("y", s);
auto hp2 = mm->insert_instruction(
std::next(p2),
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
p2);
auto hp1 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p1);
auto hp2 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p2);
auto hsum = mm->add_instruction(migraphx::make_op("add"), hp1, hp2);
auto sum = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
hsum);
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), hsum);
auto diff = mm->add_instruction(migraphx::make_op("sub"), sum, p2);
auto hdiff = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
diff);
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), diff);
auto res = mm->add_instruction(migraphx::make_op("add"), hdiff, hp1);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
res);
auto r = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), res);
mm->add_return({r});
return p;
};
......@@ -174,51 +147,18 @@ TEST_CASE(param_add_sub)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = mm->add_parameter("x", s);
auto p2 = mm->add_parameter("y", s);
auto hp2 = mm->insert_instruction(
std::next(p2),
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
p2);
auto p1 = mm->add_parameter("x", s);
auto p2 = mm->add_parameter("y", s);
auto sum = mm->add_instruction(migraphx::make_op("add"), p1, p2);
auto hsum = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
sum);
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), sum);
auto hp2 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p2);
auto hdiff = mm->add_instruction(migraphx::make_op("sub"), hsum, hp2);
auto diff = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
hdiff);
mm->add_instruction(migraphx::make_op("add"), diff, p1);
return p;
};
auto create_program_half_all = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = mm->add_parameter("x", s);
auto hp1 = mm->insert_instruction(
std::next(p1),
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
p1);
auto p2 = mm->add_parameter("y", s);
auto hp2 = mm->insert_instruction(
std::next(p2),
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
p2);
auto hsum = mm->add_instruction(migraphx::make_op("add"), hp1, hp2);
auto hdiff = mm->add_instruction(migraphx::make_op("sub"), hsum, hp2);
auto hres = mm->add_instruction(migraphx::make_op("add"), hdiff, hp1);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
hres);
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), hdiff);
auto r = mm->add_instruction(migraphx::make_op("add"), diff, p1);
mm->add_return({r});
return p;
};
......@@ -236,17 +176,70 @@ TEST_CASE(param_add_sub)
auto p2 = create_program_half_sub();
migraphx::quantize_fp16(p1, {"sub"});
EXPECT(p1 == p2);
}
{
auto p1 = create_program_float();
auto p2 = create_program_half_all();
auto create_program_fp16 = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = mm->add_parameter("x", s);
auto p2 = mm->add_parameter("y", s);
auto hp1 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p1);
auto hp2 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p2);
auto hsum = mm->add_instruction(migraphx::make_op("add"), hp1, hp2);
auto sum = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), hsum);
auto hsum1 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), sum);
auto p3 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p2);
auto diff = mm->add_instruction(migraphx::make_op("sub"), hsum1, p3);
auto fdiff = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), diff);
auto hdiff1 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), fdiff);
auto p4 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p1);
auto res = mm->add_instruction(migraphx::make_op("add"), hdiff1, p4);
auto r = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), res);
mm->add_return({r});
return p;
};
auto create_program_quant_fp16 = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = mm->add_parameter("x", s);
auto p2 = mm->add_parameter("y", s);
auto hp1 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p1);
auto hp2 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p2);
auto hsum = mm->add_instruction(migraphx::make_op("add"), hp1, hp2);
auto hdiff = mm->add_instruction(migraphx::make_op("sub"), hsum, hp2);
auto hres = mm->add_instruction(migraphx::make_op("add"), hdiff, hp1);
auto r = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), hres);
mm->add_return({r});
return p;
};
auto p0 = create_program_float();
migraphx::run_passes(p0, {migraphx::quantize_fp16_pass{{"all"}}});
EXPECT(p0 == create_program_fp16());
auto p1 = create_program_float();
migraphx::quantize_fp16(p1);
migraphx::run_passes(*p1.get_main_module(), {migraphx::dead_code_elimination{}});
EXPECT(p1 == p2);
EXPECT(p1 == create_program_quant_fp16());
}
}
......@@ -308,13 +301,125 @@ TEST_CASE(literal_add)
}
}
TEST_CASE(op_capture)
TEST_CASE(fp16_subgraph)
{
auto test_func = [&](std::size_t ins_index, const std::vector<migraphx::argument>& args) {
(void)ins_index;
(void)args;
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {1}};
auto l1 = mm->add_literal(migraphx::literal(sd, {1}));
auto l2 = mm->add_literal(migraphx::literal(sd, {2}));
auto l3 = mm->add_literal(migraphx::literal(sd, {3}));
migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
migraphx::shape sc{migraphx::shape::bool_type};
auto cond = mm->add_parameter("cond", sc);
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sy);
auto* then_mod = p.create_module("If_6_if");
auto m1 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1);
auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2);
auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2);
auto mfp16 = then_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), mul0);
then_mod->add_return({add0, mul0, mfp16});
auto* else_mod = p.create_module("If_6_else");
auto me1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3);
auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1);
auto me2 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3);
auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2);
auto afp16 = else_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), add1);
else_mod->add_return({mul1, add1, afp16});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
auto r16 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), ret);
mm->add_return({r0, r1, r16});
return p;
};
auto create_fp16_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {1}};
auto l1 = mm->add_literal(migraphx::literal(sd, {1}));
auto l2 = mm->add_literal(migraphx::literal(sd, {2}));
auto l3 = mm->add_literal(migraphx::literal(sd, {3}));
migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
migraphx::shape sc{migraphx::shape::bool_type};
auto cond = mm->add_parameter("cond", sc);
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sy);
auto* then_mod = p.create_module("If_6_if");
auto hl1 = then_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l1);
auto mhl1 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), hl1);
auto hx = then_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), x);
auto ad = then_mod->add_instruction(migraphx::make_op("add"), hx, mhl1);
auto fad = then_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), ad);
auto hl2 = then_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l2);
auto mhl2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), hl2);
auto hy1 = then_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), y);
auto mu = then_mod->add_instruction(migraphx::make_op("mul"), hy1, mhl2);
auto fmu = then_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), mu);
then_mod->add_return({fad, fmu, mu});
auto* else_mod = p.create_module("If_6_else");
auto hl3 = else_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l3);
auto mhl3 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), hl3);
auto hx2 = else_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), x);
auto mu1 = else_mod->add_instruction(migraphx::make_op("mul"), hx2, mhl3);
auto fmu1 = else_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), mu1);
auto mhl4 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), hl3);
auto hy = else_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), y);
auto ad1 = else_mod->add_instruction(migraphx::make_op("add"), hy, mhl4);
auto fad1 = else_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), ad1);
else_mod->add_return({fmu1, fad1, ad1});
auto iff = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), iff);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), iff);
auto r2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), iff);
mm->add_return({r0, r1, r2});
return p;
};
auto p1 = create_program();
migraphx::quantize_fp16(p1);
auto p2 = create_fp16_program();
EXPECT(p1 == p2);
}
TEST_CASE(op_capture)
{
auto create_program_float = [] {
migraphx::program p;
auto* mm = p.get_main_module();
......@@ -343,186 +448,105 @@ TEST_CASE(op_capture)
auto pb = mm->add_parameter("b", s2);
auto pc = mm->add_parameter("c", s2);
auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2);
auto opb = mm->insert_instruction(std::next(pb), migraphx::op::capture{1, test_func}, pb);
auto opc = mm->insert_instruction(std::next(pc), migraphx::op::capture{2, test_func}, pc);
auto opa = mm->add_instruction(migraphx::op::capture{0, test_func}, pa);
auto opa = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 0}}), pa);
auto opb = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 1}}), pb);
auto opc = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 2}}), pc);
auto ps = mm->add_instruction(migraphx::make_op("dot"), opa, opb, opc);
auto ops = mm->add_instruction(migraphx::op::capture{3, test_func}, ps);
mm->add_instruction(migraphx::make_op("dot"), opa, ops);
auto opm = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 3}}), pa);
auto ops = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 4}}), ps);
mm->add_instruction(migraphx::make_op("dot"), opm, ops);
return p;
};
{
auto p = create_program_float();
auto op_capture_p = create_program_op();
migraphx::target t = migraphx::ref::target{};
migraphx::capture_arguments(p, t, {"dot", "convolution"});
auto p = create_program_float();
auto op_capture_p = create_program_op();
migraphx::target t = migraphx::ref::target{};
std::size_t param_index = 0;
migraphx::run_passes(
p, {migraphx::capture_arguments_pass{{"dot", "convolution"}, {}, &param_index}});
EXPECT(p == op_capture_p);
}
}
TEST_CASE(dot_float)
TEST_CASE(op_capture_subgraph)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 2.0f}, {"beta", 1.5f}}), pa, pb, pc);
return p;
};
auto create_int8_quantized_prog = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
// quantize parameter a to int8 type, multiply the scale
std::vector<float> vfa(sa.elements(), 0.1f);
auto fa = mm->add_literal(migraphx::literal(sa, vfa));
auto ma = mm->add_instruction(migraphx::make_op("mul"), fa, pa);
auto ra = mm->add_instruction(migraphx::make_op("round"), ma);
auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
ca);
// quantize parameter b to int8 type
auto insert_loc = std::next(pb);
std::vector<float> vfb(sb.elements(), 0.1f);
auto fb = mm->add_literal(migraphx::literal(sb, vfb));
auto mb = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fb, pb);
auto rb = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mb);
auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cb);
migraphx::shape sx{migraphx::shape::float_type, {2, 2, 4, 8}};
migraphx::shape sy{migraphx::shape::float_type, {2, 2, 8, 6}};
migraphx::shape sc{migraphx::shape::bool_type};
auto cond = mm->add_parameter("cond", sc);
auto a = mm->add_parameter("a", sx);
auto b = mm->add_parameter("b", sy);
auto qdot = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qa, qb);
auto fdot = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
qdot);
std::vector<float> v_alpha(fdot->get_shape().elements(), 200.0f);
auto new_alpha = mm->add_literal(migraphx::literal(fdot->get_shape(), v_alpha));
auto alpha_ab = mm->add_instruction(migraphx::make_op("mul"), new_alpha, fdot);
std::vector<float> v_beta(pc->get_shape().elements(), 1.5f);
auto beta = mm->add_literal(migraphx::literal(pc->get_shape(), v_beta));
auto beta_c = mm->add_instruction(migraphx::make_op("mul"), beta, pc);
mm->add_instruction(migraphx::make_op("add"), alpha_ab, beta_c);
migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}};
migraphx::shape sw{migraphx::shape::float_type, {2, 2, 1, 1}};
auto x = mm->add_parameter("x", sd);
auto w = mm->add_parameter("w", sw);
return p;
};
auto* then_mod = p.create_module("If_6_if");
auto out1 = then_mod->add_instruction(migraphx::make_op("dot"), a, b);
then_mod->add_return({out1});
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 0.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
}
TEST_CASE(dot_double_2args)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::double_type, {2, 16}};
migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto* else_mod = p.create_module("If_6_else");
auto out2 = else_mod->add_instruction(migraphx::make_op("convolution"), x, w);
else_mod->add_return({out2});
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 2.0f}, {"beta", 1.5f}}), pa, pb);
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
return p;
};
auto create_int8_quantized_prog = [] {
auto create_program_op = [&] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::double_type, {2, 16}};
migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
migraphx::shape sc{migraphx::shape::double_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
// quantize parameter a to int8 type, multiply the scale
std::vector<float> vfa(sa.elements(), 0.1f);
auto fpa = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pa);
auto fa = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa));
auto ma = mm->add_instruction(migraphx::make_op("mul"), fa, fpa);
auto ra = mm->add_instruction(migraphx::make_op("round"), ma);
auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
ca);
// quantize parameter b to int8 type
auto insert_loc = std::next(pb);
auto fpb = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pb);
std::vector<float> vfb(sb.elements(), 0.1f);
auto fb = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
auto mb = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fb, fpb);
auto rb = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mb);
auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cb);
auto qdot = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qa, qb);
auto fdot = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
qdot);
std::vector<float> v_alpha(fdot->get_shape().elements(), 200.0f);
auto new_alpha = mm->add_literal(migraphx::literal(fdot->get_shape(), v_alpha));
auto alpha_ab = mm->add_instruction(migraphx::make_op("mul"), new_alpha, fdot);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::double_type)}}),
alpha_ab);
migraphx::shape sx{migraphx::shape::float_type, {2, 2, 4, 8}};
migraphx::shape sy{migraphx::shape::float_type, {2, 2, 8, 6}};
migraphx::shape sc{migraphx::shape::bool_type};
auto cond = mm->add_parameter("cond", sc);
auto a = mm->add_parameter("a", sx);
auto b = mm->add_parameter("b", sy);
migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}};
migraphx::shape sw{migraphx::shape::float_type, {2, 2, 1, 1}};
auto x = mm->add_parameter("x", sd);
auto w = mm->add_parameter("w", sw);
auto* then_mod = p.create_module("If_6_if");
auto ca = then_mod->add_instruction(migraphx::make_op("capture", {{"ins_index", 2}}), a);
auto cb = then_mod->add_instruction(migraphx::make_op("capture", {{"ins_index", 3}}), b);
auto out1 = then_mod->add_instruction(migraphx::make_op("dot"), ca, cb);
then_mod->add_return({out1});
auto* else_mod = p.create_module("If_6_else");
auto cx = else_mod->add_instruction(migraphx::make_op("capture", {{"ins_index", 0}}), x);
auto cw = else_mod->add_instruction(migraphx::make_op("capture", {{"ins_index", 1}}), w);
auto out2 = else_mod->add_instruction(migraphx::make_op("convolution"), cx, cw);
else_mod->add_return({out2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog();
{
auto p = create_program();
auto op_capture_p = create_program_op();
migraphx::target t = migraphx::ref::target{};
std::size_t param_index = 0;
migraphx::run_passes(
p, {migraphx::capture_arguments_pass{{"dot", "convolution"}, {}, &param_index}});
EXPECT(p == qp);
EXPECT(p == op_capture_p);
}
}
TEST_CASE(dot_large_alpha_beta_float)
TEST_CASE(dot_float)
{
auto create_program = [] {
migraphx::program p;
......@@ -534,8 +558,9 @@ TEST_CASE(dot_large_alpha_beta_float)
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 20.0f}, {"beta", 50.5f}}), pa, pb, pc);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), pa, pb, pc);
mm->add_return({r});
return p;
};
......@@ -546,153 +571,103 @@ TEST_CASE(dot_large_alpha_beta_float)
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
// quantize parameter a to int8 type, multiply the scale
std::vector<float> vfa(sa.elements(), 0.1f);
auto fa = mm->add_literal(migraphx::literal(sa, vfa));
auto ma = mm->add_instruction(migraphx::make_op("mul"), fa, pa);
// add the shift
std::vector<float> vsa(sa.elements(), 1.0f);
auto sfta = mm->add_literal(migraphx::literal(sa, vsa));
auto msa = mm->add_instruction(migraphx::make_op("add"), sfta, ma);
auto ra = mm->add_instruction(migraphx::make_op("round"), msa);
auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
ca);
// quantize parameter b to int8 type
auto insert_loc = std::next(pb);
std::vector<float> vfb(sb.elements(), 0.1f);
auto fb = mm->add_literal(migraphx::literal(sb, vfb));
auto mb = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fb, pb);
auto rb = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mb);
auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cb);
// quantize parameter c to int32 type
auto qc = mm->insert_instruction(
std::next(pc),
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
pc);
auto qdot = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 2000}, {"beta", 51}}), qa, qb, qc);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
qdot);
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
}
TEST_CASE(dot_large_alpha_beta_int32)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::int32_type, {2, 16}};
migraphx::shape sb{migraphx::shape::int32_type, {16, 8}};
migraphx::shape sc{migraphx::shape::int32_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 20.0f}, {"beta", 50.0f}}), pa, pb, pc);
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto zp_a = mm->add_literal(static_cast<int8_t>(0));
auto scale_a = mm->add_literal(10.0f);
scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a);
zp_a = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}),
zp_a);
auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a);
auto dqa = mm->add_instruction(migraphx::make_op("dequantizelinear"), qa, scale_a, zp_a);
auto zp_b = mm->add_literal(static_cast<int8_t>(0));
auto scale_b = mm->add_literal(10.0f);
scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b);
zp_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}),
zp_b);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto zp_c = mm->add_literal(static_cast<int8_t>(100));
auto scale_c = mm->add_literal(10.0f);
scale_c = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}), scale_c);
zp_c = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}),
zp_c);
auto qc = mm->add_instruction(migraphx::make_op("quantizelinear"), pc, scale_c, zp_c);
auto dqc = mm->add_instruction(migraphx::make_op("dequantizelinear"), qc, scale_c, zp_c);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), dqa, dqb, dqc);
mm->add_return({r});
return p;
};
auto create_int8_quantized_prog = [] {
auto create_int8_optimized_prog = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::int32_type, {2, 16}};
migraphx::shape sb{migraphx::shape::int32_type, {16, 8}};
migraphx::shape sc{migraphx::shape::int32_type, {2, 8}};
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
// quantize parameter a to int8 type, multiply the scale
std::vector<float> vfa(sa.elements(), 0.1f);
auto fa = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa));
auto conv_a = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pa);
auto ma = mm->add_instruction(migraphx::make_op("mul"), fa, conv_a);
// add the shift
std::vector<float> vsa(sa.elements(), 1.0f);
auto sfta =
mm->add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vsa));
auto msa = mm->add_instruction(migraphx::make_op("add"), sfta, ma);
auto ra = mm->add_instruction(migraphx::make_op("round"), msa);
auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
ca);
// quantize parameter b to int8 type
auto insert_loc = std::next(pb);
std::vector<float> vfb(sb.elements(), 0.1f);
auto fb = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
auto conv_b = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pb);
auto mb = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fb, conv_b);
auto rb = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mb);
auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cb);
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 2000}, {"beta", 50}}), qa, qb, pc);
mm->add_parameter("c", sc);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(10.0f);
auto scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale);
auto zp_a =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto quant_a = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a);
auto scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale);
auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto quant = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), quant_a, quant_b);
std::vector<float> vec(sc.elements(), 100.0f);
auto dc = mm->add_literal(100.0f);
auto mdc =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}), dc);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, mdc);
mm->add_return({r});
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
const std::vector<std::pair<float, float>> quant_params = {
{0.1f, 0.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
auto p = create_program();
std::size_t param_index = 0;
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes(
p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
optimize_prog_int8(p);
auto op = create_int8_optimized_prog();
EXPECT(p == op);
}
TEST_CASE(dot_int32_one_arg)
TEST_CASE(dot_double_2args)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {16, 16}};
auto pa = mm->add_parameter("a", s);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 20.0f}, {"beta", 50.0f}}), pa, pa);
migraphx::shape sa{migraphx::shape::double_type, {2, 16}};
migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), pa, pb);
mm->add_return({r});
return p;
};
......@@ -700,177 +675,87 @@ TEST_CASE(dot_int32_one_arg)
auto create_int8_quantized_prog = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {16, 16}};
auto pa = mm->add_parameter("a", s);
// add the shift
auto fpa = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pa);
std::vector<float> vsa(s.elements(), 1.0f);
auto sfta =
mm->add_literal(migraphx::literal({migraphx::shape::float_type, s.lens()}, vsa));
auto msa = mm->add_instruction(migraphx::make_op("add"), sfta, fpa);
auto ra = mm->add_instruction(migraphx::make_op("round"), msa);
auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
ca);
auto q_dot = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qa, qa);
auto f_dot = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
q_dot);
std::vector<float> v_alpha(f_dot->get_shape().elements(), 20.0f);
auto new_alpha = mm->add_literal(migraphx::literal{f_dot->get_shape(), v_alpha});
auto alpha_ab = mm->add_instruction(migraphx::make_op("mul"), new_alpha, f_dot);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
alpha_ab);
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{1.0f, 1.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
}
TEST_CASE(dot_int32)
{
auto create_program = [](bool add_return = false) {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::int32_type, {2, 16}};
migraphx::shape sb{migraphx::shape::int32_type, {16, 8}};
migraphx::shape sc{migraphx::shape::int32_type, {2, 8}};
migraphx::shape sa{migraphx::shape::double_type, {2, 16}};
migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto res = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 2.0f}, {"beta", 5.5f}}), pa, pb, pc);
if(add_return)
{
mm->add_return({res});
}
auto zp_a = mm->add_literal(static_cast<int8_t>(0));
auto scale_a = mm->add_literal(10.0);
scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a);
zp_a = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}),
zp_a);
auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a);
auto dqa = mm->add_instruction(migraphx::make_op("dequantizelinear"), qa, scale_a, zp_a);
auto zp_b = mm->add_literal(static_cast<int8_t>(0));
auto scale_b = mm->add_literal(5.0);
scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b);
zp_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}),
zp_b);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), dqa, dqb);
mm->add_return({r});
return p;
};
auto create_int8_quantized_prog = [](bool add_return = false) {
auto create_int8_optimized_prog = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::int32_type, {2, 16}};
migraphx::shape sb{migraphx::shape::int32_type, {16, 8}};
migraphx::shape sc{migraphx::shape::int32_type, {2, 8}};
migraphx::shape sa{migraphx::shape::double_type, {2, 16}};
migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
// quantize parameter a to int8 type, multiply the scale
std::vector<float> vfa(sa.elements(), 0.1f);
auto fa = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa));
auto conv_a = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pa);
auto ma = mm->add_instruction(migraphx::make_op("mul"), fa, conv_a);
// add the shift
std::vector<float> vsa(sa.elements(), 1.0f);
auto sfta =
mm->add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vsa));
auto msa = mm->add_instruction(migraphx::make_op("add"), sfta, ma);
auto ra = mm->add_instruction(migraphx::make_op("round"), msa);
auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
ca);
// quantize parameter b to int8 type
auto insert_loc = std::next(pb);
std::vector<float> vfb(sb.elements(), 0.1f);
auto fb = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
auto conv_b = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pb);
auto mb = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fb, conv_b);
auto rb = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mb);
auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cb);
auto scale_a = mm->add_literal(10.0);
auto zp = mm->add_literal(static_cast<int8_t>(0));
scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a);
auto zp_a =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a);
auto scale_b = mm->add_literal(5.0);
scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b);
auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto qdot = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qa, qb);
auto fr = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
qdot);
std::vector<float> v_alpha(fr->get_shape().elements(), 20.0f);
auto new_alpha = mm->add_literal(migraphx::literal(fr->get_shape(), v_alpha));
auto alpha_ab = mm->add_instruction(migraphx::make_op("mul"), new_alpha, fr);
auto fc = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pc);
std::vector<float> v_beta(fc->get_shape().elements(), 5.5f);
auto beta = mm->add_literal(migraphx::literal(fc->get_shape(), v_beta));
auto beta_c = mm->add_instruction(migraphx::make_op("mul"), beta, fc);
auto f_res = mm->add_instruction(migraphx::make_op("add"), alpha_ab, beta_c);
auto res = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
f_res);
if(add_return)
{
mm->add_return({res});
}
auto scale = mm->add_literal(50.0);
scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, scale);
mm->add_return({r});
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
auto p_ret = create_program(true);
migraphx::quantize_int8_impl(p_ret, quant_params, {"dot"});
auto qp_ret = create_int8_quantized_prog(true);
EXPECT(p_ret == qp_ret);
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.2f, 0.0f}};
std::size_t param_index = 0;
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes(
p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}});
EXPECT(p == create_int8_quantized_prog());
optimize_prog_int8(p);
EXPECT(p == create_int8_optimized_prog());
}
TEST_CASE(dot_float_convert)
TEST_CASE(dot_half_1arg)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::int8_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto fpa = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pa);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 2.0f}, {"beta", 5.5f}}), fpa, pb);
migraphx::shape s{migraphx::shape::half_type, {9, 9}};
auto x = mm->add_parameter("x", s);
auto r =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), x, x);
mm->add_return({r});
return p;
};
......@@ -878,44 +763,67 @@ TEST_CASE(dot_float_convert)
auto create_int8_quantized_prog = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::int8_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
// quantize parameter b to int8 type
auto insert_loc = std::next(pb);
std::vector<float> vfb(sb.elements(), 0.1f);
auto fb = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
auto mb = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fb, pb);
auto rb = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mb);
auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cb);
migraphx::shape sa{migraphx::shape::half_type, {9, 9}};
auto x = mm->add_parameter("x", sa);
auto zp_a = mm->add_literal(static_cast<int8_t>(0));
auto scale_a = mm->add_literal(migraphx::literal({sa.type()}, {10.0}));
scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a);
zp_a = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}),
zp_a);
auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale_a, zp_a);
auto dqa = mm->add_instruction(migraphx::make_op("dequantizelinear"), qa, scale_a, zp_a);
auto zp_b = mm->add_literal(static_cast<int8_t>(0));
auto scale_b = mm->add_literal(migraphx::literal({sa.type()}, {10.0}));
scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_b);
zp_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}),
zp_b);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), dqa, dqb);
mm->add_return({r});
return p;
};
auto create_int8_optimized_prog = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::half_type, {9, 9}};
auto x = mm->add_parameter("x", sa);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(migraphx::literal({sa.type()}, {10.0}));
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}),
scale);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp);
auto qdot = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), pa, qb);
auto fr = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
qdot);
std::vector<float> v_alpha(fr->get_shape().elements(), 10.0f);
auto new_alpha = mm->add_literal(migraphx::literal(fr->get_shape(), v_alpha));
mm->add_instruction(migraphx::make_op("mul"), new_alpha, fr);
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qx, qx);
auto dq_scale = mm->add_literal(migraphx::literal({sa.type()}, {100.0}));
dq_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
dq_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, dq_scale);
mm->add_return({r});
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 1.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
std::size_t param_index = 0;
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes(
p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}});
EXPECT(p == create_int8_quantized_prog());
optimize_prog_int8(p);
EXPECT(p == create_int8_optimized_prog());
}
TEST_CASE(conv_float)
......@@ -927,7 +835,8 @@ TEST_CASE(conv_float)
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto r = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
mm->add_return({r});
return p;
};
......@@ -939,119 +848,61 @@ TEST_CASE(conv_float)
migraphx::shape sw{migraphx::shape::float_type, {4, 3, 3, 3}};
auto px = mm->add_parameter("x", sx);
auto pw = mm->add_parameter("w", sw);
// quantize parameter a to int8 type, multiply the scale
std::vector<float> vfx(sx.elements(), 0.1f);
auto fx = mm->add_literal(migraphx::literal(sx, vfx));
auto mx = mm->add_instruction(migraphx::make_op("mul"), fx, px);
auto rx = mm->add_instruction(migraphx::make_op("round"), mx);
auto cx = create_clip_op(p, 127.0f, -128.0f, rx);
auto qx = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cx);
// quantize parameter b to int8 type
auto insert_loc = std::next(pw);
std::vector<float> vfw(sw.elements(), 0.1f);
auto fw = mm->add_literal(migraphx::literal(sw, vfw));
auto mw = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fw, pw);
auto rw = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mw);
auto cw = create_clip_op(insert_loc, p, 127.0f, -128.0f, rw);
auto qw = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cw);
auto q_conv = mm->add_instruction(migraphx::make_op("quant_convolution"), qx, qw);
auto f_conv = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
q_conv);
std::vector<float> v_adj(f_conv->get_shape().elements(), 100.0f);
auto adj = mm->add_literal(migraphx::literal(f_conv->get_shape(), v_adj));
mm->add_instruction(migraphx::make_op("mul"), adj, f_conv);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(10.0f);
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}),
scale);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp);
auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp);
auto quant_w = mm->add_instruction(migraphx::make_op("quantizelinear"), pw, scale, zp);
auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w);
migraphx::shape sc{migraphx::shape::float_type, {4, 4, 1, 1}};
std::vector<float> vec(sc.elements(), 100.0f);
migraphx::shape s_scale{migraphx::shape::float_type, sc.lens()};
auto d_scale = mm->add_literal(100.0f);
d_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale);
mm->add_return({r});
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"convolution"});
std::size_t param_index = 0;
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}});
migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"convolution"}, quant_params}});
optimize_prog_int8(p);
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
}
TEST_CASE(conv_int32)
TEST_CASE(conv_float_throw)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {4, 3, 3, 3}});
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::int32_type, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("convolution"), input, weights);
return p;
};
auto create_int8_quantized_prog = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sx{migraphx::shape::int32_type, {4, 3, 3, 3}};
migraphx::shape sw{migraphx::shape::int32_type, {4, 3, 3, 3}};
auto px = mm->add_parameter("x", sx);
auto pw = mm->add_parameter("w", sw);
// quantize parameter a to int8 type, multiply the scale
auto fpx = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
px);
std::vector<float> vfx(sx.elements(), 0.1f);
auto fx = mm->add_literal(migraphx::literal(fpx->get_shape(), vfx));
auto mx = mm->add_instruction(migraphx::make_op("mul"), fx, fpx);
auto rx = mm->add_instruction(migraphx::make_op("round"), mx);
auto cx = create_clip_op(p, 127.0f, -128.0f, rx);
auto qx = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cx);
// quantize parameter b to int8 type
auto insert_loc = std::next(pw);
auto fpw = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pw);
std::vector<float> vfw(sw.elements(), 0.1f);
auto fw = mm->add_literal(migraphx::literal(fpw->get_shape(), vfw));
auto mw = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fw, fpw);
auto rw = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mw);
auto cw = create_clip_op(insert_loc, p, 127.0f, -128.0f, rw);
auto qw = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cw);
auto q_conv = mm->add_instruction(migraphx::make_op("quant_convolution"), qx, qw);
std::vector<float> v_adj(q_conv->get_shape().elements(), 100.0f);
auto adj = mm->add_literal(migraphx::literal(q_conv->get_shape(), v_adj));
mm->add_instruction(migraphx::make_op("mul"), q_conv, adj);
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto r = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
mm->add_return({r});
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"convolution"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
test::throws([&] {
migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"add"}, quant_params}});
});
}
TEST_CASE(conv_half)
......@@ -1063,7 +914,8 @@ TEST_CASE(conv_half)
mm->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto r = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
mm->add_return({r});
return p;
};
......@@ -1075,63 +927,43 @@ TEST_CASE(conv_half)
migraphx::shape sw{migraphx::shape::half_type, {4, 3, 3, 3}};
auto px = mm->add_parameter("x", sx);
auto pw = mm->add_parameter("w", sw);
// quantize parameter a to int8 type, multiply the scale
auto fpx = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
px);
std::vector<float> vfx(sx.elements(), 0.1f);
auto fx = mm->add_literal(migraphx::literal(fpx->get_shape(), vfx));
auto mx = mm->add_instruction(migraphx::make_op("mul"), fx, fpx);
auto rx = mm->add_instruction(migraphx::make_op("round"), mx);
auto cx = create_clip_op(p, 127.0f, -128.0f, rx);
auto qx = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cx);
// quantize parameter b to int8 type
auto insert_loc = std::next(pw);
auto fpw = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pw);
std::vector<float> vfw(sw.elements(), 0.1f);
auto fw = mm->add_literal(migraphx::literal(fpw->get_shape(), vfw));
auto mw = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fw, fpw);
auto rw = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mw);
auto cw = create_clip_op(insert_loc, p, 127.0f, -128.0f, rw);
auto qw = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cw);
auto q_conv = mm->add_instruction(migraphx::make_op("quant_convolution"), qx, qw);
auto f_conv = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
q_conv);
std::vector<float> v_adj(f_conv->get_shape().elements(), 100.0f);
auto adj = mm->add_literal(migraphx::literal(f_conv->get_shape(), v_adj));
auto f_res = mm->add_instruction(migraphx::make_op("mul"), adj, f_conv);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
f_res);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(migraphx::literal({sx.type()}, {10.0}));
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}),
scale);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp);
auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp);
auto quant_w = mm->add_instruction(migraphx::make_op("quantizelinear"), pw, scale, zp);
auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w);
auto d_scale = mm->add_literal(migraphx::literal({sx.type()}, {100.0}));
d_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale);
mm->add_return({r});
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"convolution"});
std::size_t param_index = 0;
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}});
migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"convolution"}, quant_params}});
optimize_prog_int8(p);
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
}
template <class T>
auto get_hash(const T& x)
{
return std::hash<T>{}(x);
}
TEST_CASE(target_copy)
{
auto run_prog = [](migraphx::program p,
......@@ -1223,7 +1055,8 @@ TEST_CASE(int8_quantization_dot)
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc);
auto r = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc);
mm->add_return({r});
return p;
};
......@@ -1232,9 +1065,9 @@ TEST_CASE(int8_quantization_dot)
auto p = create_program();
migraphx::parameter_map m;
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
m["a"] = migraphx::generate_argument(sa);
m["c"] = migraphx::generate_argument(sc);
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
m["a"] = migraphx::generate_argument(sa, get_hash(std::string("a")));
m["b"] = migraphx::generate_argument(sb, get_hash(std::string("b")));
std::vector<float> quant_result;
migraphx::target ref_t = migraphx::ref::target{};
run_prog(p, ref_t, m, quant_result, true);
......@@ -1272,7 +1105,8 @@ TEST_CASE(int8_quantization_conv)
std::vector<float> v(sx.elements(), 0.5f);
auto input = mm->add_literal(migraphx::literal(sx, v));
auto weights = mm->add_literal(migraphx::literal(sw, v));
mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto r = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
mm->add_return({r});
return p;
};
......@@ -1290,4 +1124,156 @@ TEST_CASE(int8_quantization_conv)
}
}
TEST_CASE(int8_subgraph)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sx{migraphx::shape::float_type, {2, 2, 4, 8}};
migraphx::shape sy{migraphx::shape::float_type, {2, 2, 8, 6}};
migraphx::shape sc{migraphx::shape::bool_type};
auto cond = mm->add_parameter("cond", sc);
auto a = mm->add_parameter("a", sx);
auto b = mm->add_parameter("b", sy);
migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}};
migraphx::shape sw{migraphx::shape::float_type, {2, 2, 1, 1}};
auto x = mm->add_parameter("x", sd);
auto w = mm->add_parameter("w", sw);
auto* then_mod = p.create_module("If_6_if");
auto out1 = then_mod->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), a, b);
then_mod->add_return({out1});
auto* else_mod = p.create_module("If_6_else");
auto out2 = else_mod->add_instruction(migraphx::make_op("convolution"), x, w);
else_mod->add_return({out2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
return p;
};
auto create_int8_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sx{migraphx::shape::float_type, {2, 2, 4, 8}};
migraphx::shape sy{migraphx::shape::float_type, {2, 2, 8, 6}};
migraphx::shape sout{migraphx::shape::float_type, {2, 2, 4, 6}};
migraphx::shape sc{migraphx::shape::bool_type};
auto cond = mm->add_parameter("cond", sc);
auto a = mm->add_parameter("a", sx);
auto b = mm->add_parameter("b", sy);
// then submod
auto* then_mod = p.create_module("If_6_if");
auto zp1 = then_mod->add_literal(static_cast<int8_t>(0));
auto s1 = then_mod->add_literal(10.0f);
auto sa = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), s1);
auto zpa = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp1);
auto qa = then_mod->add_instruction(migraphx::make_op("quantizelinear"), a, sa, zpa);
auto sb = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), s1);
auto zpb = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), zp1);
auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb);
auto qdot =
then_mod->add_instruction(migraphx::make_op("quant_dot", {{"beta", 0}}), qa, qb);
auto so = then_mod->add_literal(100.0f);
so = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so);
auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so);
then_mod->add_return({r});
migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}};
migraphx::shape sw{migraphx::shape::float_type, {2, 2, 1, 1}};
auto x = mm->add_parameter("x", sd);
auto w = mm->add_parameter("w", sw);
// else submod
auto* else_mod = p.create_module("If_6_else");
auto sax = else_mod->add_literal(2.0f);
auto zp = else_mod->add_literal(static_cast<int8_t>(0));
sax = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), sax);
auto zpx = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), zp);
auto qx = else_mod->add_instruction(migraphx::make_op("quantizelinear"), x, sax, zpx);
auto ssw = else_mod->add_literal(1.66667f);
ssw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), ssw);
auto zpw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), zp);
auto qw = else_mod->add_instruction(migraphx::make_op("quantizelinear"), w, ssw, zpw);
auto qconv = else_mod->add_instruction(migraphx::make_op("quant_convolution"), qx, qw);
auto so1 = else_mod->add_literal(3.33333f);
so1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so1);
auto r1 = else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1);
else_mod->add_return({r1});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
return p;
};
auto p1 = create_program();
const std::vector<std::pair<float, float>>& quant_params{
{0.5f, 0.0f}, {0.6f, 0.0f}, {0.1f, 0.0f}, {0.1f, 0.0f}};
std::size_t param_index = 0;
migraphx::run_passes(
p1, {migraphx::capture_arguments_pass{{"convolution", "dot"}, {}, &param_index}});
migraphx::run_passes(p1, {migraphx::quantize_int8_pass{{"convolution", "dot"}, quant_params}});
optimize_prog_int8(p1);
auto p2 = create_int8_program();
EXPECT(p1 == p2);
}
TEST_CASE(test_op_capture)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
std::vector<float> d1(s1.elements());
std::vector<float> d2(s2.elements());
std::iota(d1.begin(), d1.end(), 0.0f);
std::iota(d2.begin(), d2.end(), 0.0f);
auto p1 = mm->add_literal(s1, d1);
auto p2 = mm->add_literal(s1, d1);
auto pb = mm->add_literal(s2, d2);
auto pc = mm->add_literal(s2, d2);
auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2);
auto ps = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc);
mm->add_instruction(migraphx::make_op("dot"), pa, ps);
auto calc = [](std::size_t, const std::vector<migraphx::argument>&) {};
migraphx::program capture_p = p;
migraphx::target t = migraphx::ref::target{};
std::size_t param_index = 0;
migraphx::run_passes(capture_p,
{migraphx::capture_arguments_pass{{"dot"}, calc, &param_index}});
p.compile(migraphx::ref::target{});
capture_p.compile(migraphx::ref::target{});
auto cap_res = capture_p.eval({}).back();
auto res = p.eval({}).back();
std::vector<float> vec;
std::vector<float> cap_vec;
cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); });
res.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(vec, cap_vec));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -668,7 +668,7 @@ TEST_CASE(matmul_vm)
auto al = mm->add_literal(migraphx::literal{a_shape, a});
auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al);
auto bual = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 1, 6}}}), ual);
migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 6}}}), ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
mm->add_instruction(migraphx::make_op("dot"), bual, bl);
......@@ -715,7 +715,7 @@ TEST_CASE(matmul_vm)
auto al = mm->add_literal(migraphx::literal{a_shape, a});
auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al);
auto bual = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 1, 6}}}), ual);
migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 6}}}), ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 0.21f}}), bual, bl);
......@@ -837,7 +837,7 @@ TEST_CASE(matmul_mv)
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl);
auto bubl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 5, 1}}}), ubl);
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 1}}}), ubl);
mm->add_instruction(migraphx::make_op("dot"), al, bubl);
std::vector<float> gold = {-0.792717,
6.33595,
......@@ -897,7 +897,7 @@ TEST_CASE(matmul_mm1)
migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto bbl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 5, 3}}}), bl);
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 3}}}), bl);
mm->add_instruction(migraphx::make_op("dot"), al, bbl);
std::vector<float> gold = {-0.386828, 0.187735, -0.22822, -0.148057, 2.015, -2.56938,
-0.782212, 1.9459, 0.927426, -2.44907, 2.40531, 2.30232,
......@@ -946,7 +946,7 @@ TEST_CASE(matmul_mm1)
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4}};
auto al = mm->add_literal(migraphx::literal{a_shape, a});
auto bal = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 3, 4}}}), al);
migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 3, 4}}}), al);
migraphx::shape b_shape{migraphx::shape::float_type, {2, 3, 4, 3}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
mm->add_instruction(migraphx::make_op("dot"), bal, bl);
......@@ -994,7 +994,7 @@ TEST_CASE(matmul_mm2)
migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto bbl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 5, 3}}}), bl);
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 3}}}), bl);
std::vector<float> gold = {
0.70574512, -2.80915314, -1.57644969, 1.75415381, -3.13303087, -1.00150259,
-0.18675123, -0.23349122, -0.12357225, 0.82911538, 1.37473744, -1.11709934,
......@@ -1030,11 +1030,11 @@ TEST_CASE(matmul_mm2)
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3, 5}};
auto al = mm->add_literal(migraphx::literal{a_shape, a});
auto bal = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 3, 5}}}), al);
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 5}}}), al);
migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto bbl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 5, 3}}}), bl);
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 3}}}), bl);
mm->add_instruction(migraphx::make_op("dot"), bal, bbl);
std::vector<float> gold = {
1.64924590e+00, 2.84575831e+00, 1.07340773e+00, 2.19817080e-01, -1.87873283e+00,
......@@ -1132,7 +1132,7 @@ TEST_CASE(matmul_mm2)
migraphx::shape b_shape{migraphx::shape::float_type, {2, 4, 5}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto bbl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 4, 5}}}), bl);
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 4, 5}}}), bl);
mm->add_instruction(migraphx::make_op("dot"), al, bbl);
std::vector<float> gold = {
-1.08585245, 0.39575611, 0.33947977, -0.86339678, 1.50710753, 0.05646156,
......@@ -1192,9 +1192,10 @@ TEST_CASE(quant_dot_2args_multi4)
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
mm->add_instruction(migraphx::make_op("quant_dot"), tl1, l2);
std::vector<int> gold = {448, 472, 496, 520, 544, 568, 592, 616, 496, 524, 552,
......@@ -1219,9 +1220,10 @@ TEST_CASE(quant_dot_2args_multi4)
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot"), l1, tl2);
std::vector<int> gold = {14, 38, 62, 86, 110, 134, 158, 182, 38, 126, 214,
......@@ -1246,10 +1248,12 @@ TEST_CASE(quant_dot_2args_multi4)
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot"), tl1, tl2);
std::vector<int> gold = {56, 152, 248, 344, 440, 536, 632, 728, 62, 174, 286,
......@@ -1302,9 +1306,10 @@ TEST_CASE(quant_dot_2args_general)
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
mm->add_instruction(migraphx::make_op("quant_dot"), tl1, l2);
std::vector<int> gold = {
......@@ -1328,9 +1333,10 @@ TEST_CASE(quant_dot_2args_general)
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 2}}), l1, tl2);
std::vector<int> gold = {
......@@ -1354,10 +1360,12 @@ TEST_CASE(quant_dot_2args_general)
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
std::vector<int> gold = {
......@@ -1446,10 +1454,11 @@ TEST_CASE(quant_dot_3args_general)
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, l2, l3);
......@@ -1477,10 +1486,11 @@ TEST_CASE(quant_dot_3args_general)
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), l1, tl2, l3);
......@@ -1508,11 +1518,13 @@ TEST_CASE(quant_dot_3args_general)
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3);
......@@ -1577,12 +1589,12 @@ TEST_CASE(quant_dot_3args_batch)
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), tl1, tl2, l3);
......
#include <iostream>
#include <vector>
#include <cmath>
#include <migraphx/literal.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/make_op.hpp>
#include "test.hpp"
static auto run_prog(int64_t iter_num, bool cond, int64_t ini_val)
{
migraphx::shape si{migraphx::shape::int64_type};
migraphx::shape s{migraphx::shape::int64_type, {1}};
migraphx::shape sc{migraphx::shape::bool_type};
auto create_program = [&]() {
migraphx::program p;
auto* mm = p.get_main_module();
auto in_iter = mm->add_parameter("iter_num", si);
auto in_cond = mm->add_parameter("ccond", sc);
auto in_val = mm->add_parameter("val", s);
auto* body = p.create_module("loop_module");
auto iter = body->add_parameter("#loop_module_in_0", si);
body->add_parameter("#loop_module_in_1", sc);
auto in_v = body->add_parameter("#loop_module_in_2", s);
std::vector<int64_t> vd = {3};
auto l = body->add_literal(migraphx::literal(si, vd));
auto ad = body->add_instruction(migraphx::make_op("add"), iter, l);
auto val = body->add_instruction(migraphx::make_op("add"), in_v, ad);
auto eq = body->add_instruction(migraphx::make_op("equal"), iter, l);
auto beq = body->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), eq);
auto neq = body->add_instruction(migraphx::make_op("not"), beq);
body->add_return({neq, val, val});
auto rl = mm->add_instruction(migraphx::make_op("loop", {{"max_iterations", 10}}),
{in_iter, in_cond, in_val},
{body});
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), rl);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), rl);
mm->add_return({r0, r1});
return p;
};
auto p = create_program();
p.compile(migraphx::ref::target{});
migraphx::parameter_map pp;
pp["iter_num"] = migraphx::argument(si, &iter_num);
pp["ccond"] = migraphx::argument(sc, &cond);
pp["val"] = migraphx::argument(s, &ini_val);
auto rets = p.eval(pp);
std::vector<std::vector<int64_t>> res;
for(auto& arg : rets)
{
std::vector<int64_t> vec;
arg.visit([&](auto v) { vec.assign(v.begin(), v.end()); });
res.push_back(vec);
}
return res;
}
TEST_CASE(loop_test1)
{
auto ress = run_prog(10, true, 1);
std::vector<int64_t> gold_last = {19};
EXPECT(ress.front() == gold_last);
std::vector<int64_t> gold_concat = {4, 8, 13, 19, 0, 0, 0, 0, 0, 0};
EXPECT(ress.back() == gold_concat);
}
TEST_CASE(loop_test2)
{
auto ress = run_prog(4, true, 1);
std::vector<int64_t> gold_last = {19};
EXPECT(ress.front() == gold_last);
std::vector<int64_t> gold_concat = {4, 8, 13, 19, 0, 0, 0, 0, 0, 0};
EXPECT(ress.back() == gold_concat);
}
TEST_CASE(loop_test3)
{
auto ress = run_prog(3, true, 1);
std::vector<int64_t> gold_last = {13};
EXPECT(ress.front() == gold_last);
std::vector<int64_t> gold_concat = {4, 8, 13, 0, 0, 0, 0, 0, 0, 0};
EXPECT(ress.back() == gold_concat);
}
TEST_CASE(loop_test4)
{
auto ress = run_prog(5, true, 2);
std::vector<int64_t> gold_last = {20};
EXPECT(ress.front() == gold_last);
std::vector<int64_t> gold_concat = {5, 9, 14, 20, 0, 0, 0, 0, 0, 0};
EXPECT(ress.back() == gold_concat);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment