Commit 2f268bc2 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into mlir-c

parents f75c5a38 aa7ff911
......@@ -180,6 +180,40 @@ TEST_CASE(duplicate_args3)
EXPECT(result == migraphx::literal{0});
}
TEST_CASE(reused_twice)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 2, 2};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, dims});
auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, dims});
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z);
auto epsilon = mm->add_literal(1e-12f);
auto exponent = mm->add_literal(2.0f);
auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), add2);
auto mean_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto sub = mm->add_instruction(migraphx::make_op("sub"), add2, mean_mbcast);
auto exponent_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent);
auto pow = mm->add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast);
auto var = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), pow);
auto epsilon_mbcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, dims.at(1), 1}}}), epsilon);
auto add_epsilon = mm->add_instruction(migraphx::make_op("add"), var, epsilon_mbcast);
mm->add_instruction(migraphx::make_op("sqrt"), add_epsilon);
mm->add_instruction(migraphx::make_op("add"), x, y);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
p.debug_print();
EXPECT(std::distance(mm->begin(), mm->end()) != count);
EXPECT(std::distance(mm->begin(), mm->end()) == 4);
}
TEST_CASE(unused_module)
{
migraphx::program p;
......
......@@ -3,6 +3,7 @@
#include <migraphx/make_op.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/program.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/gpu/kernel.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
......@@ -109,6 +110,24 @@ int main() {}
)__migraphx__";
// NOLINTNEXTLINE
const std::string math_template = R"__migraphx__(
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/math.hpp>
extern "C" {
__global__ void kernel(${type}* p)
{
auto x = *p;
*p = migraphx::implicit_conversion(migraphx::${invoke});
}
}
int main() {}
)__migraphx__";
migraphx::src_file make_src_file(const std::string& name, const std::string& content)
{
return {name, std::make_pair(content.data(), content.data() + content.size())};
......@@ -248,4 +267,66 @@ TEST_CASE(compile_pointwise)
EXPECT(result == output_literal.get_argument());
}
TEST_CASE(compile_math)
{
std::vector<std::string> math_invoke = {
// clang-format off
"abs(x)",
"acos(x)",
"acosh(x)",
"asin(x)",
"asinh(x)",
"atan(x)",
"atanh(x)",
"ceil(x)",
"cos(x)",
"cosh(x)",
"erf(x)",
"exp(x)",
"floor(x)",
"isnan(x)",
"log(x)",
"max(x, x)",
"min(x, x)",
"pow(x, 0)",
"pow(x, x)",
"round(x)",
"rsqrt(x)",
"sin(x)",
"sinh(x)",
"sqrt(x)",
"tan(x)",
"tanh(x)",
"where(true, x, x)",
// clang-format on
};
std::vector<std::string> data_types;
auto vec_sizes = {2, 4, 6};
for(auto&& t : migraphx::shape::types())
{
if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
continue;
auto name = migraphx::shape::cpp_type(t);
if(t == migraphx::shape::half_type)
name.insert(0, "migraphx::");
data_types.push_back(name);
migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) {
return "migraphx::vec<" + name + ", " + std::to_string(i) + ">";
});
}
migraphx::shape input{migraphx::shape::float_type, {5, 2}};
migraphx::gpu::hip_compile_options options;
options.global = 1024;
options.local = 1024;
options.inputs = {input};
options.output = input;
migraphx::par_for(math_invoke.size() * data_types.size(), 1, [&](auto i) {
const auto& t = data_types[i % data_types.size()];
const auto& invoke = math_invoke[i / data_types.size()];
auto src = migraphx::interpolate_string(math_template, {{"type", t}, {"invoke", invoke}});
auto co = migraphx::gpu::compile_hip_code_object(src, options);
(void)co;
});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -332,7 +332,7 @@ TEST_CASE(match_either_args_any1)
match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_either_args_any2)
......@@ -347,7 +347,7 @@ TEST_CASE(match_either_args_any2)
match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_either_args_any3)
......@@ -362,7 +362,7 @@ TEST_CASE(match_either_args_any3)
match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_either_args_any4)
......@@ -377,7 +377,7 @@ TEST_CASE(match_either_args_any4)
match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_either_args_any5)
......@@ -392,7 +392,7 @@ TEST_CASE(match_either_args_any5)
match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_all_of1)
......@@ -747,10 +747,10 @@ TEST_CASE(match_bind1)
match::standard_shape())
.bind("pass");
auto r = find_match(mm, m);
EXPECT(bool{r.instructions.at("one") == one});
EXPECT(bool{r.instructions.at("two") == two});
EXPECT(bool{r.instructions.at("sum") == sum});
EXPECT(bool{r.instructions.at("pass") == pass});
EXPECT(bool{r.instructions["one"] == one});
EXPECT(bool{r.instructions["two"] == two});
EXPECT(bool{r.instructions["sum"] == sum});
EXPECT(bool{r.instructions["pass"] == pass});
EXPECT(bool{r.result == pass});
}
......@@ -795,9 +795,9 @@ TEST_CASE(match_bind_modules2)
match::standard_shape())
.bind("pass");
auto r = find_match(*child, m);
EXPECT(bool{r.instructions.at("two") == two});
EXPECT(bool{r.instructions.at("sum") == sum});
EXPECT(bool{r.instructions.at("pass") == pass});
EXPECT(bool{r.instructions["two"] == two});
EXPECT(bool{r.instructions["sum"] == sum});
EXPECT(bool{r.instructions["pass"] == pass});
EXPECT(bool{r.result == pass});
}
......
gathernd_batch_dims_test:
/
data
indicesy"GatherND*
batch_dimsgathernd_batch_dims_testZ
data



Z
indices


b
y


B
\ No newline at end of file
 gathernd_test:q

data
indicesy"GatherND gathernd_testZ
data


Z
indices


b
y

B
\ No newline at end of file
......@@ -1666,6 +1666,35 @@ def gather_elements_axis1_test():
return ([node], [x, i], [y])
@onnx_test
def gathernd_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2])
i = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2])
node = onnx.helper.make_node('GatherND',
inputs=['data', 'indices'],
outputs=['y'])
return ([node], [x, i], [y])
@onnx_test
def gathernd_batch_dims_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2])
i = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2])
node = onnx.helper.make_node(
'GatherND',
inputs=['data', 'indices'],
outputs=['y'],
batch_dims=1,
)
return ([node], [x, i], [y])
@onnx_test
def gemm_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 7])
......@@ -1749,6 +1778,20 @@ def globalavgpool_test():
return ([node], [x], [y])
@onnx_test
def globallppool_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 16, 16])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 1, 1])
node = onnx.helper.make_node(
'GlobalLpPool',
inputs=['0'],
outputs=['1'],
)
return ([node], [x], [y])
@onnx_test
def globalmaxpool_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 16, 16])
......@@ -2868,6 +2911,32 @@ def lpnormalization_p_error_test():
return ([node], [x], [y])
@onnx_test
def lppool_l1_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 3, 5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 3, 3])
node = onnx.helper.make_node('LpPool',
inputs=['x'],
outputs=['y'],
kernel_shape=[3],
p=1)
return ([node], [x], [y])
@onnx_test
def lppool_l2_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 3, 5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 3, 3])
node = onnx.helper.make_node('LpPool',
inputs=['x'],
outputs=['y'],
kernel_shape=[3],
p=2)
return ([node], [x], [y])
@onnx_test
def lrn_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 28, 24, 24])
......@@ -3109,6 +3178,20 @@ def mean_test():
return ([node], data, [mean])
@onnx_test
def mean_integral_test():
data = [
helper.make_tensor_value_info(str(i), TensorProto.INT32, [2, 2, 2])
for i in range(10)
]
data_names = [str(i) for i in range(10)]
mean = helper.make_tensor_value_info('mean', TensorProto.INT32, [2, 2, 2])
node = onnx.helper.make_node("Mean", inputs=data_names, outputs=["mean"])
return ([node], data, [mean])
@onnx_test
def min_test():
a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
......@@ -4345,6 +4428,142 @@ def resize_upsample_pc_test():
return ([node], [X], [Y], [scale_tensor])
@onnx_test
def reversesequence_4D_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 2, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 2, 2])
node = onnx.helper.make_node(
'ReverseSequence',
inputs=['x'],
outputs=['y'],
time_axis=0,
batch_axis=1,
sequence_lens=[2, 1],
)
return ([node], [x], [y])
@onnx_test
def reversesequence_batch_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4])
seq_lens = np.array([1, 2, 3, 4])
seq_lens_tensor = helper.make_tensor(
name="sequence_lens",
data_type=TensorProto.INT64,
dims=seq_lens.shape,
vals=seq_lens.astype(np.int64),
)
arg_seq_lens = helper.make_node(
"Constant",
inputs=[],
outputs=['arg_seq_lens'],
value=seq_lens_tensor,
)
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4])
node = onnx.helper.make_node(
'ReverseSequence',
inputs=['x', 'arg_seq_lens'],
outputs=['y'],
time_axis=1,
batch_axis=0,
)
return ([arg_seq_lens, node], [x], [y])
@onnx_test
def reversesequence_batch_axis_err_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4, 2])
node = onnx.helper.make_node(
'ReverseSequence',
inputs=['x'],
outputs=['y'],
time_axis=0,
batch_axis=2,
sequence_lens=[4, 3, 2, 1],
)
return ([node], [x], [y])
@onnx_test
def reversesequence_rank_err_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4])
node = onnx.helper.make_node(
'ReverseSequence',
inputs=['x'],
outputs=['y'],
sequence_lens=[4, 3, 2, 1],
)
return ([node], [x], [y])
@onnx_test
def reversesequence_sequence_lens_shape_err_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4])
node = onnx.helper.make_node(
'ReverseSequence',
inputs=['x'],
outputs=['y'],
sequence_lens=[4, 3, 2],
)
return ([node], [x], [y])
@onnx_test
def reversesequence_same_axis_err_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4])
node = onnx.helper.make_node(
'ReverseSequence',
inputs=['x'],
outputs=['y'],
time_axis=1,
batch_axis=1,
sequence_lens=[4, 3, 2, 1],
)
return ([node], [x], [y])
@onnx_test
def reversesequence_time_axis_err_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4, 2, 3])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4, 2, 3])
node = onnx.helper.make_node(
'ReverseSequence',
inputs=['x'],
outputs=['y'],
time_axis=3,
batch_axis=0,
sequence_lens=[4, 3, 2, 1],
)
return ([node], [x], [y])
@onnx_test
def reversesequence_time_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4])
node = onnx.helper.make_node(
'ReverseSequence',
inputs=['x'],
outputs=['y'],
time_axis=0,
batch_axis=1,
sequence_lens=[4, 3, 2, 1],
)
return ([node], [x], [y])
@onnx_test
def roialign_default_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 4, 7, 8])
......@@ -4381,7 +4600,47 @@ def roialign_test():
@onnx_test
def scatter_test():
def scatter_add_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6])
i = helper.make_tensor_value_info('indices', TensorProto.INT32,
[2, 3, 4, 5])
u = helper.make_tensor_value_info('update', TensorProto.FLOAT,
[2, 3, 4, 5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5, 6])
node = onnx.helper.make_node(
'ScatterElements',
reduction='add',
inputs=['data', 'indices', 'update'],
outputs=['y'],
axis=-2,
)
return ([node], [x, i, u], [y])
@onnx_test
def scatter_mul_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6])
i = helper.make_tensor_value_info('indices', TensorProto.INT32,
[2, 3, 4, 5])
u = helper.make_tensor_value_info('update', TensorProto.FLOAT,
[2, 3, 4, 5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5, 6])
node = onnx.helper.make_node(
'ScatterElements',
reduction='mul',
inputs=['data', 'indices', 'update'],
outputs=['y'],
axis=-2,
)
return ([node], [x, i, u], [y])
@onnx_test
def scatter_none_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6])
i = helper.make_tensor_value_info('indices', TensorProto.INT32,
[2, 3, 4, 5])
......@@ -4390,7 +4649,8 @@ def scatter_test():
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5, 6])
node = onnx.helper.make_node(
'Scatter',
'ScatterElements',
reduction='none',
inputs=['data', 'indices', 'update'],
outputs=['y'],
axis=-2,
......
globallppool_test:c

01" GlobalLpPoolgloballppool_testZ
0




b
1




B
\ No newline at end of file
lppool_l1_test:q
-
xy"LpPool*
kernel_shape@*
plppool_l1_testZ
x



b
y



B
\ No newline at end of file
lppool_l2_test:q
-
xy"LpPool*
kernel_shape@*
plppool_l2_testZ
x



b
y



B
\ No newline at end of file
mean_integral_test:Ö
*
0
1
2
3
4
5
6
7
8
9mean"Meanmean_integral_testZ
0



Z
1



Z
2



Z
3



Z
4



Z
5



Z
6



Z
7



Z
8



Z
9



b
mean



B
\ No newline at end of file
......@@ -1582,6 +1582,31 @@ TEST_CASE(gather_elements_axis1_test)
EXPECT(p == prog);
}
TEST_CASE(gathernd_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 2}});
mm->add_instruction(migraphx::make_op("gathernd"), l0, l1);
auto prog = optimize_onnx("gathernd_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(gathernd_batch_dims_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1}});
int batch_dims = 1;
mm->add_instruction(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), l0, l1);
auto prog = optimize_onnx("gathernd_batch_dims_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(gemm_test)
{
migraphx::program p;
......@@ -1704,6 +1729,23 @@ TEST_CASE(globalavgpool_test)
EXPECT(p == prog);
}
TEST_CASE(globallppool_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::lpnorm};
auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]};
op.padding = {0, 0, 0, 0};
mm->add_instruction(op, input);
auto prog = optimize_onnx("globallppool_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(globalmaxpool_test)
{
migraphx::program p;
......@@ -2596,6 +2638,38 @@ TEST_CASE(lpnormalization_p_error_test)
EXPECT(test::throws([&] { migraphx::parse_onnx("lpnormalization_p_error_test.onnx"); }));
}
TEST_CASE(lppool_l1_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 3, 5}});
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::lpnorm},
{"padding", {0, 0}},
{"stride", {1}},
{"lengths", {3}},
{"lp_order", 1}}),
l0);
auto prog = optimize_onnx("lppool_l1_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(lppool_l2_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 3, 5}});
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::lpnorm},
{"padding", {0, 0}},
{"stride", {1}},
{"lengths", {3}},
{"lp_order", 2}}),
l0);
auto prog = optimize_onnx("lppool_l2_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(lrn_test)
{
migraphx::program p;
......@@ -2816,6 +2890,30 @@ TEST_CASE(mean_test)
EXPECT(p == prog);
}
TEST_CASE(mean_integral_test)
{
const std::size_t num_data = 10;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {2, 2, 2}};
auto mean = mm->add_parameter("0", s);
for(std::size_t i = 1; i < num_data; ++i)
{
auto data = mm->add_parameter(std::to_string(i), s);
mean = mm->add_instruction(migraphx::make_op("add"), mean, data);
}
auto div_lit = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {num_data}});
auto divisor =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), div_lit);
mean = mm->add_instruction(migraphx::make_op("div"), mean, divisor);
auto prog = optimize_onnx("mean_integral_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(min_test)
{
migraphx::program p;
......@@ -3733,7 +3831,6 @@ TEST_CASE(reshape_non_standard_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::op::reshape op;
std::vector<int64_t> reshape_dims{4, 3, 2};
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}};
auto x = mm->add_parameter("x", s);
auto tran_x =
......@@ -4173,6 +4270,126 @@ TEST_CASE(resize_upsample_pf_test)
EXPECT(p == prog);
}
TEST_CASE(reversesequence_batch_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
int batch_axis = 0;
int time_axis = 1;
migraphx::shape sx{migraphx::shape::float_type, {4, 4}};
auto input = mm->add_parameter("x", sx);
std::vector<int64_t> sequence_lens = {1, 2, 3, 4};
mm->add_literal({{migraphx::shape::int64_type, {4}}, sequence_lens});
int batch_size = sx.lens()[batch_axis];
int time_size = sx.lens()[time_axis];
auto add_slice =
[&mm, &input, batch_axis, time_axis](int b_start, int b_end, int t_start, int t_end) {
return mm->add_instruction(migraphx::make_op("slice",
{{"axes", {batch_axis, time_axis}},
{"starts", {b_start, t_start}},
{"ends", {b_end, t_end}}}),
input);
};
auto ret = add_slice(0, 1, 0, time_size);
for(int b = 1; b < batch_size; ++b)
{
auto s0 = add_slice(b, b + 1, 0, sequence_lens[b]);
s0 = mm->add_instruction(migraphx::make_op("reverse", {{"axes", {time_axis}}}), s0);
if(sequence_lens[b] < time_size)
{
auto s1 = add_slice(b, b + 1, sequence_lens[b], time_size);
s0 = mm->add_instruction(migraphx::make_op("concat", {{"axis", time_axis}}), s0, s1);
}
ret = mm->add_instruction(migraphx::make_op("concat", {{"axis", batch_axis}}), ret, s0);
}
mm->add_return({ret});
auto prog = migraphx::parse_onnx("reversesequence_batch_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(reversesequence_batch_axis_err_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_batch_axis_err_test.onnx"); }));
}
TEST_CASE(reversesequence_rank_err_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_rank_err_test.onnx"); }));
}
TEST_CASE(reversesequence_sequence_lens_shape_err_test)
{
EXPECT(test::throws(
[&] { migraphx::parse_onnx("reversesequence_sequence_lens_shape_err_test.onnx"); }));
}
TEST_CASE(reversesequence_same_axis_err_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_same_axis_err_test.onnx"); }));
}
TEST_CASE(reversesequence_time_axis_err_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_time_axis_err_test.onnx"); }));
}
TEST_CASE(reversesequence_time_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
int batch_axis = 1;
int time_axis = 0;
migraphx::shape sx{migraphx::shape::float_type, {4, 4}};
auto input = mm->add_parameter("x", sx);
int batch_size = sx.lens()[batch_axis];
int time_size = sx.lens()[time_axis];
std::vector<int64_t> sequence_lens = {4, 3, 2, 1};
auto add_slice =
[&mm, &input, batch_axis, time_axis](int b_start, int b_end, int t_start, int t_end) {
return mm->add_instruction(migraphx::make_op("slice",
{{"axes", {batch_axis, time_axis}},
{"starts", {b_start, t_start}},
{"ends", {b_end, t_end}}}),
input);
};
migraphx::instruction_ref ret;
for(int b = 0; b < batch_size - 1; ++b)
{
auto s0 = add_slice(b, b + 1, 0, sequence_lens[b]);
s0 = mm->add_instruction(migraphx::make_op("reverse", {{"axes", {time_axis}}}), s0);
if(sequence_lens[b] < time_size)
{
auto s1 = add_slice(b, b + 1, sequence_lens[b], time_size);
s0 = mm->add_instruction(migraphx::make_op("concat", {{"axis", time_axis}}), s0, s1);
}
if(b == 0)
{
ret = s0;
}
else
{
ret = mm->add_instruction(migraphx::make_op("concat", {{"axis", batch_axis}}), ret, s0);
}
}
auto s0 = add_slice(batch_size - 1, batch_size, 0, time_size);
ret = mm->add_instruction(migraphx::make_op("concat", {{"axis", batch_axis}}), ret, s0);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("reversesequence_time_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(roialign_default_test)
{
migraphx::shape sx{migraphx::shape::float_type, {10, 4, 7, 8}};
......@@ -4233,7 +4450,8 @@ TEST_CASE(round_test)
EXPECT(p == prog);
}
TEST_CASE(scatter_test)
// the ScatterElements op has 3 reduction modes, which map to separate reference ops
migraphx::program create_scatter_program(const std::string& scatter_mode, int axis)
{
migraphx::program p;
auto* mm = p.get_main_module();
......@@ -4242,10 +4460,30 @@ TEST_CASE(scatter_test)
mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3, 4, 5}});
auto l2 =
mm->add_parameter("update", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
int axis = -2;
auto r = mm->add_instruction(migraphx::make_op("scatter", {{"axis", axis}}), l0, l1, l2);
auto r = mm->add_instruction(migraphx::make_op(scatter_mode, {{"axis", axis}}), l0, l1, l2);
mm->add_return({r});
auto prog = migraphx::parse_onnx("scatter_test.onnx");
return p;
}
TEST_CASE(scatter_add_test)
{
migraphx::program p = create_scatter_program("scatter_add", -2);
auto prog = migraphx::parse_onnx("scatter_add_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(scatter_mul_test)
{
migraphx::program p = create_scatter_program("scatter_mul", -2);
auto prog = migraphx::parse_onnx("scatter_mul_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(scatter_none_test)
{
migraphx::program p = create_scatter_program("scatter_none", -2);
auto prog = migraphx::parse_onnx("scatter_none_test.onnx");
EXPECT(p == prog);
}
......
reversesequence_rank_err_test:v
3
xy"ReverseSequence*
sequence_lens@@@@reversesequence_rank_err_testZ
x

b
y

B
\ No newline at end of file
"reversesequence_same_axis_err_test:
X
xy"ReverseSequence*
batch_axis*
sequence_lens@@@@*
time_axis"reversesequence_same_axis_err_testZ
x


b
y


B
\ No newline at end of file
,reversesequence_sequence_lens_shape_err_test:‹
1
xy"ReverseSequence*
sequence_lens@@@ ,reversesequence_sequence_lens_shape_err_testZ
x


b
y


B
\ No newline at end of file
 scatter_test:
9
scatter_add_test:
V
data
indices
updatey"Scatter*
axis scatter_testZ
updatey"ScatterElements*
axis*
reduction"addscatter_add_testZ
data


......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment