Unverified Commit c4cee345 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Merge branch 'develop' into rocblas_fp8

parents c40a39c3 eafd55de
  scatternd_invalid_reduction_test:à
D
data
indices
updatesoutput" ScatterND*
reduction"invalid  scatternd_invalid_reduction_testZ
data



Z
indices



Z
updates



b
output



B
\ No newline at end of file
 scatternd_max_test:
@
data
indices
updatesoutput" ScatterND*
reduction"maxscatternd_max_testZ
data



Z
indices



Z
updates



b
output



B
\ No newline at end of file
 scatternd_min_test:Î
@
data
indices
updatesoutput" ScatterND*
reduction"min scatternd_min_testZ
data



Z
indices



Z
updates



b
output



B
\ No newline at end of file
 unique_dynamic_sorted_3D_test:Ö
?
XYindicesinverse_indicescounts"Unique*
sorted unique_dynamic_sorted_3D_testZ
X



b
Y

b
indices

b
inverse_indices

@b
counts

B
\ No newline at end of file
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -1686,6 +1686,252 @@ TEST_CASE(qlinearadd_bcast_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearaveragepool_1d_test)
{
auto p = migraphx::parse_onnx("qlinearaveragepool_1d_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<int8_t> data_x = {
-31, 51, 125, 30, -17, -125, 121, -19, -13, 52, 18, -70, 97, 15, 56, 42,
-65, -26, 40, -109, -70, 83, 110, -94, 34, 70, 5, -23, -60, -68, 19, 48,
-113, 3, -44, 20, -99, -103, -49, -38, 122, 75, 38, -7, -65, -56, 96, 99,
50, -27, -114, 49, -65, 105, -3, 54, 8, 38, -81, -46, -86, -46, -104, 36,
22, -51, 48, 59, -116, 6, 93, 16, -111, 98, 51, -87, -111, -74, -39, 7,
107, 115, 59, 60, -66, -14, -106, -23, 119, -122, -51, -100, 26, 125, 45, 90};
migraphx::shape s_x{migraphx::shape::int8_type, {1, 3, 32}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data_x.data());
auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold = {
26, 104, 94, 22, -55, 14, 67, 0, 36, 51, -10, 29, 72, 52, 65, 5,
-30, 23, -19, -74, 23, 112, 24, -14, 68, 54, 7, -26, -48, -8, 50, -39,
-4, 4, -24, -85, -60, -28, 58, 114, 72, 31, -20, -44, 36, 114, 90, 28,
-54, -16, 8, 36, 67, 42, 47, 39, -6, -48, -50, -50, -59, -18, 2, 15,
70, -13, -39, 66, 71, -32, 9, 90, -2, -83, -76, -40, 0, 73, 127, 103,
75, 13, -24, -44, -48, 64, 15, -70, -60, -21, 92, 101, 84};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearaveragepool_2d_test)
{
auto p = migraphx::parse_onnx("qlinearaveragepool_2d_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<int8_t> data_x = {84, -73, 117, -2, -97, 72, 67, 27, 1, -44, 110, 51,
9, 7, 58, 113, -34, 34, 124, -20, 6, 66, 68, 98,
31, -84, 25, 101, -69, -100, -68, 116, 33, -121, 78, 49,
102, -86, 65, 69, -87, -89, 16, -125, 51, -54, -86, 79};
migraphx::shape s_x{migraphx::shape::int8_type, {1, 3, 4, 4}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data_x.data());
auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold = {4, 127, 127, -41, 127, 127, -6, 125, 127,
76, 127, 127, 32, 78, 127, -128, -128, 127,
-44, -37, 127, -117, -62, 37, -128, -128, -81};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearaveragepool_2d_ceil_test)
{
auto p = migraphx::parse_onnx("qlinearaveragepool_2d_ceil_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<uint8_t> data_x = {2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32};
migraphx::shape s_x{migraphx::shape::uint8_type, {1, 1, 4, 4}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data_x.data());
auto result = p.eval(pp).back();
std::vector<uint8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<uint8_t> gold = {120, 150, 240, 255};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearaveragepool_2d_dilations_test)
{
auto p = migraphx::parse_onnx("qlinearaveragepool_2d_dilations_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<int8_t> data_x = {2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32};
migraphx::shape s_x{migraphx::shape::int8_type, {1, 1, 4, 4}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data_x.data());
auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold = {108, 112, 124, 127};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearaveragepool_2d_pads_count_include_pad_test)
{
auto p = migraphx::parse_onnx("qlinearaveragepool_2d_pads_count_include_pad_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<int8_t> data_x = {-30, 50, 91, -87, -21, -113, -16, 6, -128, 104, 82, -126,
54, 41, -71, 62, -11, -111, 13, 104, -43, -48, 30, 85,
-62, -33, -27, -114, 32, -17, 30, -26, -18, 15, 17, 100,
-122, 115, 84, -34, -86, 82, 102, -117, -91, -105, 112, 91};
migraphx::shape s_x{migraphx::shape::int8_type, {1, 3, 4, 4}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data_x.data());
auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold = {
15, 43, 94, 62, 34, -16, 4, -31, 10, -6, 29, -13, -67, -45, 43, 27, 4, -83,
-21, -3, -6, 15, -3, 0, -9, 71, 78, 83, 3, -4, 62, 85, 45, 50, 27, 66,
26, -36, -29, 35, 97, 90, 2, -86, -62, 73, 127, 127, -32, -128, -128, -24, 83, 74,
-9, -63, -45, -35, 20, 1, 15, -12, -11, -72, -44, -46, 50, 40, 57, 25, 34, 18,
22, 30, 40, 105, 97, 88, -46, 26, 83, 127, 125, 69, -94, 24, 127, 127, 116, 4,
-128, -83, 83, 127, 127, -1, -66, -79, 40, 124, 127, 18, -19, -77, -15, 86, 127, 83};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearaveragepool_2d_same_lower_test)
{
auto p = migraphx::parse_onnx("qlinearaveragepool_2d_same_lower_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<uint8_t> data_x = {195, 102, 250, 61, 222, 6, 243, 218, 230, 105, 36, 116,
194, 31, 113, 85, 126, 204, 80, 38, 115, 167, 221, 67,
69, 140, 11, 209, 136, 120, 39, 96, 29, 5, 167, 40,
58, 51, 157, 179, 244, 149, 76, 243, 126, 144, 192, 199};
migraphx::shape s_x{migraphx::shape::uint8_type, {1, 3, 4, 4}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data_x.data());
auto result = p.eval(pp).back();
std::vector<uint8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<uint8_t> gold = {195, 148, 176, 156, 208, 131, 150, 193, 226, 141, 98, 153,
212, 140, 71, 88, 126, 165, 142, 59, 120, 153, 168, 102,
92, 123, 135, 127, 102, 116, 78, 89, 29, 17, 86, 104,
44, 36, 95, 136, 151, 126, 108, 164, 185, 166, 140, 178};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearaveragepool_2d_same_upper_test)
{
auto p = migraphx::parse_onnx("qlinearaveragepool_2d_same_upper_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<int8_t> data_x = {-61, 102, -6, 61, -34, 6, -13, -38, -26, 105, 36, 116,
-62, 31, 113, 85, 126, -52, 80, 38, 115, -89, -35, 67,
69, -116, 11, -47, -120, 120, 39, 96, 29, 5, -89, 40,
58, 51, -99, -77, -12, -107, 76, -13, 126, -112, -64, -57};
migraphx::shape s_x{migraphx::shape::int8_type, {1, 3, 4, 4}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data_x.data());
auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold = {
-58, -20, -62, -41, -38, 3, -14, 14, -40, 78, 111, 127, -95, 80, 127, 106,
-14, -112, 11, 41, -74, -128, -66, -44, -88, -37, -14, -15, -64, 95, 71, 127,
8, -128, -128, -101, -69, -104, -120, -128, -116, -128, -93, -128, -50, -128, -128, -128};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearaveragepool_2d_strides_test)
{
auto p = migraphx::parse_onnx("qlinearaveragepool_2d_strides_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<int8_t> data_x = {
84, -73, 117, -2, -97, 72, 67, 27, 1, -44, 110, 51, 9, 7, 58, 113,
-34, 34, 124, -20, 6, 66, 68, 98, 31, -84, 25, 101, -69, -100, -68, 116,
33, -121, 78, 49, 102, -86, 65, 69, -87, -89, 16, -125, 51, -54, -86, 79,
-112, -37, -6, 74, 118, -75, -41, 52, 101, -22, -28, -92, -59, -128, 32, 78,
-20, 121, 11, -107, -92, -31, 81, 117, -55, -3, 80, 119, 126, -98, -11, 52,
-4, -66, 37, -57, -16, -33, -12, 100, 55, 2, 27, 62, -15, 64, -74, -21,
-123, 22, -45, 12, 30, 24, 20, 120, -36, -102, -75, -39, -76, 55, 74, -120,
103, 67, -80, -89, -112, 36, 69, 98, 110, -82, 60, 119, 98, 88, 5, 42,
-88, -86, -58, -33, 93, 80, -57, -56, 87, 7, -4, 114, -73, -91, -12, -123,
96, -99, -31, -99, 85, 34, -126, 106, 88, 126, -60, 14, 75, -117, -15, 6,
55, -14, 117, -87, -75, -50, -85, 54, 70, 125, 74, -100, 25, -112, 74, -66,
-116, -102, 1, -75, -107, 83, -120, -66, 57, 29, 62, -45, -103, -56, 90, -53};
migraphx::shape s_x{migraphx::shape::int8_type, {1, 3, 8, 8}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data_x.data());
auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold = {24, 37, 10, 17, 12, 12, -13, -1, 14, -10, 7, -19};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearaveragepool_3d_test)
{
auto p = migraphx::parse_onnx("qlinearaveragepool_3d_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<int8_t> data_x = {
-61, 102, -6, 61, -34, 6, -13, -38, -26, 105, 36, 116, -62, 31, 113, 85, 126,
-52, 80, 38, 115, -89, -35, 67, 69, -116, 11, -47, -120, 120, 39, 96, 29, 5,
-89, 40, 58, 51, -99, -77, -12, -107, 76, -13, 126, -112, -64, -57, 99, -54, 27,
99, 126, -46, -7, 109, 17, 77, 94, -92, 84, -92, 48, 71, 45, -102, 95, 118,
24, 13, -70, 33, 35, -60, 102, 81, 34, 108, -79, 14, -42};
migraphx::shape s_x{migraphx::shape::int8_type, {1, 3, 3, 3, 3}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data_x.data());
auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold = {56, 114, 49, 39, 32, 127, 3, 45, -4, -13, 8, 22,
-35, -98, 76, 15, 127, 67, 100, 20, 127, 84, 64, 68};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearaveragepool_notset_test)
{
auto p = migraphx::parse_onnx("qlinearaveragepool_notset_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<int8_t> data_x = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
migraphx::shape s_x{migraphx::shape::int8_type, {1, 1, 5, 5}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data_x.data());
auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold = {22};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearaveragepool_nt_cip_test)
{
// github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearAveragePool
auto p = migraphx::parse_onnx("qlinearaveragepool_nt_cip_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<uint8_t> data_x = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
migraphx::shape s_x{migraphx::shape::uint8_type, {1, 1, 5, 5}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data_x.data());
auto result = p.eval(pp).back();
std::vector<uint8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<uint8_t> gold = {18};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearconv_test)
{
// https://xadupre.github.io/draft/onnx/onnx_doc_folder/onnx__QLinearConv.html
......@@ -1819,6 +2065,35 @@ TEST_CASE(qlinearglobalavgpool_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearleakyrelu_test)
{
// github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearSigmoid
migraphx::program p = migraphx::parse_onnx("qlinearleakyrelu_test.onnx");
p.compile(migraphx::make_target("ref"));
migraphx::shape x{migraphx::shape::int8_type, {64}};
std::vector<int8_t> data_x = {
-128, -124, -120, -116, -112, -108, -104, -100, -96, -92, -88, -84, -80, -76, -72, -68,
-64, -60, -56, -52, -48, -44, -40, -36, -32, -28, -24, -20, -16, -12, -8, -4,
0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60,
64, 68, 72, 76, 80, 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124};
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(x, data_x.data());
auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold = {
-128, -126, -122, -118, -113, -109, -104, -100, -96, -91, -87, -82, -78, -74, -69, -65,
-60, -56, -52, -47, -43, -38, -34, -30, -25, -21, -16, -12, -8, -3, 1, 6,
10, 14, 18, 22, 26, 30, 34, 38, 42, 46, 50, 54, 58, 62, 66, 70,
74, 78, 82, 86, 90, 94, 98, 102, 106, 110, 114, 118, 122, 126, 127, 127};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearmatmul_1D_test)
{
migraphx::program p = migraphx::parse_onnx("qlinearmatmul_1D_test.onnx");
......@@ -1970,6 +2245,36 @@ TEST_CASE(qlinearmul_bcast_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearsigmoid_test)
{
// github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearSigmoid
migraphx::program p = migraphx::parse_onnx("qlinearsigmoid_test.onnx");
p.compile(migraphx::make_target("ref"));
migraphx::shape x{migraphx::shape::int8_type, {64}};
std::vector<int8_t> data_x = {
-128, -124, -120, -116, -112, -108, -104, -100, -96, -92, -88, -84, -80, -76, -72, -68,
-64, -60, -56, -52, -48, -44, -40, -36, -32, -28, -24, -20, -16, -12, -8, -4,
0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60,
64, 68, 72, 76, 80, 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124};
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(x, data_x.data());
auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold = {-128, -127, -127, -127, -127, -127, -126, -126, -126, -125, -125,
-124, -123, -122, -120, -119, -117, -114, -112, -108, -104, -99,
-94, -87, -80, -71, -62, -51, -39, -27, -13, 1, 15,
29, 43, 56, 69, 81, 92, 101, 110, 117, 124, 127,
127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127,
127, 127, 127, 127, 127, 127, 127, 127, 127};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(resize_downsample_f_test)
{
migraphx::program p = migraphx::parse_onnx("resize_downsample_f_test.onnx");
......@@ -2430,67 +2735,6 @@ TEST_CASE(softsign_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(upsample_test)
{
migraphx::program p = migraphx::parse_onnx("upsample_test.onnx");
std::vector<float> x_data = {1, 2, 3, 4};
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(sx, x_data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(where_test)
{
migraphx::program p = migraphx::parse_onnx("where_test.onnx");
p.compile(migraphx::make_target("ref"));
migraphx::shape c_shape{migraphx::shape::bool_type, {2}};
std::vector<int8_t> c_data = {1, 0};
migraphx::shape x_shape{migraphx::shape::float_type, {2, 2, 2}};
std::vector<float> x_data(8, 1.0f);
migraphx::shape y_shape{migraphx::shape::float_type, {2, 1, 2, 2}};
std::vector<float> y_data(8, 2.0f);
migraphx::parameter_map pp;
pp["c"] = migraphx::argument(c_shape, c_data.data());
pp["x"] = migraphx::argument(x_shape, x_data.data());
pp["y"] = migraphx::argument(y_shape, y_data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
std::vector<float> gen_trilu_test(const migraphx::shape& s, const migraphx::program& p)
{
// input data filled with values 1 to nelements
......@@ -2616,4 +2860,131 @@ TEST_CASE(tril_row_one_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(upsample_test)
{
migraphx::program p = migraphx::parse_onnx("upsample_test.onnx");
std::vector<float> x_data = {1, 2, 3, 4};
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(sx, x_data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(unique_dynamic_sorted_test)
{
migraphx::program p = migraphx::parse_onnx("unique_dynamic_sorted_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<float> x{2, 1, 1, 3, 4, 3};
std::vector<float> y_gold = {1, 2, 3, 4};
std::vector<size_t> y_idx_gold = {1, 0, 3, 4};
std::vector<size_t> x_idx_gold = {1, 0, 0, 2, 3, 2};
std::vector<size_t> y_ct_gold = {2, 1, 2, 1};
migraphx::shape s{migraphx::shape::float_type, {x.size()}};
migraphx::parameter_map pm;
pm["X"] = migraphx::argument(s, x.data());
auto result = p.eval(pm);
std::vector<float> yvec;
result[0].visit([&](auto out) { yvec.assign(out.begin(), out.end()); });
EXPECT(yvec == y_gold);
std::vector<size_t> y_idx_vec;
result[1].visit([&](auto out) { y_idx_vec.assign(out.begin(), out.end()); });
EXPECT(y_idx_vec == y_idx_gold);
std::vector<size_t> x_idx_vec;
result[2].visit([&](auto out) { x_idx_vec.assign(out.begin(), out.end()); });
EXPECT(x_idx_vec == x_idx_gold);
std::vector<size_t> y_ct_vec;
result[3].visit([&](auto out) { y_ct_vec.assign(out.begin(), out.end()); });
EXPECT(y_ct_vec == y_ct_gold);
}
TEST_CASE(unique_dynamic_unsorted_test)
{
migraphx::program p = migraphx::parse_onnx("unique_dynamic_unsorted_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<float> x{2, 1, 1, 3, 4, 3};
std::vector<float> y_gold = {2, 1, 3, 4};
std::vector<size_t> y_idx_gold = {0, 1, 3, 4};
std::vector<size_t> x_idx_gold = {0, 1, 1, 2, 3, 2};
std::vector<size_t> y_ct_gold = {1, 2, 2, 1};
migraphx::shape s{migraphx::shape::float_type, {x.size()}};
migraphx::parameter_map pm;
pm["X"] = migraphx::argument(s, x.data());
auto result = p.eval(pm);
std::vector<float> yvec;
result[0].visit([&](auto out) { yvec.assign(out.begin(), out.end()); });
EXPECT(yvec == y_gold);
std::vector<size_t> y_idx_vec;
result[1].visit([&](auto out) { y_idx_vec.assign(out.begin(), out.end()); });
EXPECT(y_idx_vec == y_idx_gold);
std::vector<size_t> x_idx_vec;
result[2].visit([&](auto out) { x_idx_vec.assign(out.begin(), out.end()); });
EXPECT(x_idx_vec == x_idx_gold);
std::vector<size_t> y_ct_vec;
result[3].visit([&](auto out) { y_ct_vec.assign(out.begin(), out.end()); });
EXPECT(y_ct_vec == y_ct_gold);
}
TEST_CASE(where_test)
{
migraphx::program p = migraphx::parse_onnx("where_test.onnx");
p.compile(migraphx::make_target("ref"));
migraphx::shape c_shape{migraphx::shape::bool_type, {2}};
std::vector<int8_t> c_data = {1, 0};
migraphx::shape x_shape{migraphx::shape::float_type, {2, 2, 2}};
std::vector<float> x_data(8, 1.0f);
migraphx::shape y_shape{migraphx::shape::float_type, {2, 1, 2, 2}};
std::vector<float> y_data(8, 2.0f);
migraphx::parameter_map pp;
pp["c"] = migraphx::argument(c_shape, c_data.data());
pp["x"] = migraphx::argument(x_shape, x_data.data());
pp["y"] = migraphx::argument(y_shape, y_data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -2202,7 +2202,8 @@ TEST_CASE(pooling_shape0)
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {1}},
{"stride", {0}},
{"lengths", {1}}}),
{"lengths", {1}},
{"dilations", {1}}}),
input);
}
......@@ -2215,7 +2216,8 @@ TEST_CASE(pooling_shape1)
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}}}),
{"lengths", {1, 1}},
{"dilations", {1, 1}}}),
input);
}
......@@ -2229,6 +2231,7 @@ TEST_CASE(pooling_shape2)
{"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}},
{"dilations", {1, 1}},
{"ceil_mode", true}}),
input);
}
......@@ -2243,6 +2246,7 @@ TEST_CASE(pooling_shape3)
{"padding", {2, 2}},
{"stride", {3, 3}},
{"lengths", {3, 3}},
{"dilations", {1, 1}},
{"ceil_mode", true}}),
input);
}
......@@ -2254,6 +2258,63 @@ TEST_CASE(pooling_shape4)
tiny_input);
}
TEST_CASE(pooling_shape5)
{
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape output{migraphx::shape::float_type, {4, 3, 1, 1}};
expect_shape(output,
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}},
{"stride", {1, 1}},
{"lengths", {2, 2}},
{"dilations", {2, 2}}}),
input);
}
TEST_CASE(pooling_shape6)
{
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape output{migraphx::shape::float_type, {4, 3, 2, 2}};
expect_shape(output,
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}},
{"stride", {2, 2}},
{"lengths", {1, 1}},
{"dilations", {2, 2}}}),
input);
}
TEST_CASE(pooling_shape7)
{
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape output{migraphx::shape::float_type, {4, 3, 2, 2}};
expect_shape(output,
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}},
{"dilations", {3, 3}},
{"ceil_mode", true}}),
input);
}
TEST_CASE(pooling_shape8)
{
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape output{migraphx::shape::float_type, {4, 3, 3, 3}};
expect_shape(output,
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {2, 2}},
{"stride", {1, 1}},
{"lengths", {3, 3}},
{"dilations", {2, 2}}}),
input);
}
TEST_CASE(pooling_dyn_shape0)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3, {3}}, {3, 3, {3}}, {3, 3}}};
......@@ -2261,7 +2322,8 @@ TEST_CASE(pooling_dyn_shape0)
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {1}},
{"stride", {0}},
{"lengths", {1}}}),
{"lengths", {1}},
{"dilations", {1}}}),
input);
}
......@@ -2274,7 +2336,8 @@ TEST_CASE(pooling_dyn_shape1)
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}}}),
{"lengths", {1, 1}},
{"dilations", {1, 1}}}),
input);
}
......@@ -2288,6 +2351,7 @@ TEST_CASE(pooling_dyn_shape2)
{"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}},
{"dilations", {1, 1}},
{"ceil_mode", true}}),
input);
}
......@@ -2302,7 +2366,8 @@ TEST_CASE(pooling_dyn_shape3)
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}}}),
{"lengths", {1, 1}},
{"dilations", {1, 1}}}),
input);
}
......@@ -2317,6 +2382,7 @@ TEST_CASE(pooling_dyn_shape4)
{"padding", {2, 2}},
{"stride", {3, 3}},
{"lengths", {3, 3}},
{"dilations", {1, 1}},
{"ceil_mode", true}}),
input);
}
......@@ -4100,6 +4166,40 @@ TEST_CASE(test_squeeze_wrong_axis)
throws_shape(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
}
TEST_CASE(test_unique_axis_invalid)
{
migraphx::shape x_shape{migraphx::shape::float_type, {10, 4, 3}};
throws_shape(migraphx::make_op("unique", {{"axis", -1}}), x_shape);
}
TEST_CASE(test_unique_axis_negative)
{
migraphx::shape x_shape{migraphx::shape::float_type, {10, 4, 3}};
std::vector<migraphx::shape::dynamic_dimension> y_dims{{1, 10}, {4, 4}, {3, 3}};
std::vector<migraphx::shape::dynamic_dimension> idx_dims{{1, 10}};
std::vector<migraphx::shape> y_dyn_shape{{migraphx::shape::float_type, y_dims},
{migraphx::shape::int64_type, idx_dims},
{migraphx::shape::int64_type, idx_dims},
{migraphx::shape::int64_type, idx_dims}};
expect_shape(y_dyn_shape, migraphx::make_op("unique", {{"axis", -3}}), x_shape);
}
TEST_CASE(test_unique_axis_none)
{
migraphx::shape x_shape{migraphx::shape::half_type, {10, 4, 3}};
std::vector<migraphx::shape::dynamic_dimension> y_dims{{1, 120}};
std::vector<migraphx::shape::dynamic_dimension> idx_dims{{1, 120}};
std::vector<migraphx::shape> y_dyn_shape{{migraphx::shape::half_type, y_dims},
{migraphx::shape::int64_type, idx_dims},
{migraphx::shape::int64_type, idx_dims},
{migraphx::shape::int64_type, idx_dims}};
expect_shape(y_dyn_shape, migraphx::make_op("unique"), x_shape);
}
TEST_CASE(test_unsqueeze)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
......
......@@ -190,7 +190,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
backend_test.exclude(
r'test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_none_no_weight_cpu'
)
backend_test.exclude(r'test_qlinearconv_cpu')
backend_test.exclude(r'test_qlinearmatmul_2D_cpu')
backend_test.exclude(r'test_qlinearmatmul_3D_cpu')
backend_test.exclude(r'test_range_float_type_positive_delta_expanded_cpu')
......@@ -576,6 +575,8 @@ def disabled_tests_onnx_1_9_0(backend_test):
backend_test.exclude(r'test_gru_batchwise_cpu')
backend_test.exclude(r'test_simple_rnn_batchwise_cpu')
# from OnnxBackendPyTorchConvertedModelTest
# MaxPool dialtion is partially supported on GPU by a workaround
# But these tests require too large allocations to work properly
backend_test.exclude(r'test_MaxPool1d_stride_padding_dilation_cpu')
backend_test.exclude(r'test_MaxPool2d_stride_padding_dilation_cpu')
......@@ -633,8 +634,6 @@ def disabled_tests_onnx_1_11_0(backend_test):
# from OnnxBackendNodeModelTest
backend_test.exclude(r'test_roialign_aligned_false_cpu')
backend_test.exclude(r'test_roialign_aligned_true_cpu')
backend_test.exclude(r'test_scatternd_add_cpu')
backend_test.exclude(r'test_scatternd_multiply_cpu')
# errors
# from OnnxBackendNodeModelTest
......@@ -743,8 +742,6 @@ def disabled_tests_onnx_1_13_0(backend_test):
r'test_reduce_sum_square_negative_axes_keepdims_example_cpu')
backend_test.exclude(
r'test_reduce_sum_square_negative_axes_keepdims_random_cpu')
backend_test.exclude(r'test_scatternd_max_cpu')
backend_test.exclude(r'test_scatternd_min_cpu')
# errors
# from OnnxBackendNodeModelTest
......
......@@ -35,12 +35,13 @@ TEST_CASE(avgpool_rank3_test)
{
// 1D case 1, input is 3D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
op.dilations = {1};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
......@@ -54,6 +55,103 @@ TEST_CASE(avgpool_rank3_test)
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(avgpool_rank3_dil_test)
{
// 1D case 1, input is 3D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
op.dilations = {2};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.35, 0.15, 0.85, 0.3, 0.1, 0.65};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(avgpool_rank3_dil_test2)
{
// 1D case 1, input is 3D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
op.dilations = {3};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.2, 0.45, 0.35};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(avgpool_rank3_pad_test)
{
// 1D case 1, input is 3D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2};
op.padding = {1};
op.stride = {1};
op.dilations = {1};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{
0.3, 0.25, 0.3, 0.25, 0.1, 0.8, 0.65, 0.7, 0.5, 0.1, 0.1, 0.4, 0.4, 0.35, 0.6};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(avgpool_rank3_pad_dil_test)
{
// 1D case 1, input is 3D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2};
op.padding = {1};
op.stride = {1};
op.dilations = {3};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.4, 0.2, 0.2, 0.9, 0.45, 0.5, 0.1, 0.35, 0.7};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(avgpool_dyn_test)
{
// Dynamic input, no padding
......@@ -65,7 +163,8 @@ TEST_CASE(avgpool_dyn_test)
{{"mode", migraphx::op::pooling_mode::average},
{"lengths", {2}},
{"padding", {0}},
{"stride", {1}}}),
{"stride", {1}},
{"dilations", {1}}}),
x);
p.compile(migraphx::make_target("ref"));
......@@ -82,7 +181,7 @@ TEST_CASE(avgpool_dyn_test)
TEST_CASE(avgpool_dyn_pad_test)
{
// Dynamic input with explicit padding/
// Dynamic input with explicit padding
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {{1, 3}, {3, 3}, {4, 4}}};
......@@ -91,7 +190,8 @@ TEST_CASE(avgpool_dyn_pad_test)
{{"mode", migraphx::op::pooling_mode::average},
{"lengths", {2}},
{"padding", {1}},
{"stride", {1}}}),
{"stride", {1}},
{"dilations", {1}}}),
x);
p.compile(migraphx::make_target("ref"));
......@@ -158,7 +258,8 @@ TEST_CASE(avgpool_dyn_auto_pad_1d_test)
// padding added will be {1, 0} to make output
// the same size as input
{"padding_mode", migraphx::op::padding_mode_t::same_lower},
{"stride", {1}}}),
{"stride", {1}},
{"dilations", {1}}}),
x);
p.compile(migraphx::make_target("ref"));
......@@ -171,8 +272,8 @@ TEST_CASE(avgpool_dyn_auto_pad_1d_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
// clang-format off
std::vector<float> gold{0.3, 0.25, 0.3, 0.25,
0.8, 0.65, 0.7, 0.5,
std::vector<float> gold{0.3, 0.25, 0.3, 0.25,
0.8, 0.65, 0.7, 0.5,
0.1, 0.4, 0.4, 0.35};
// clang-format on
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
......@@ -190,7 +291,8 @@ TEST_CASE(avgpool_dyn_pad_ceil_test)
{"lengths", {2, 3}},
{"padding", {1, 2}},
{"ceil_mode", true},
{"stride", {1, 1}}}),
{"stride", {1, 1}},
{"dilations", {1, 1}}}),
x);
p.compile(migraphx::make_target("ref"));
......@@ -219,12 +321,13 @@ TEST_CASE(avgpool_rank3_stride2_test)
{
// 1D case 2, stride 2
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2};
op.padding = {1};
op.stride = {2};
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2};
op.padding = {1};
op.stride = {2};
op.dilations = {1};
// clang-format off
std::vector<float> data{1.6321, -2.4186, 0.2239, -1.4232,
......@@ -252,12 +355,13 @@ TEST_CASE(avgpool_rank5_test)
{
// 3D, input is 5D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2, 2, 2};
op.padding = {0, 0, 0};
op.stride = {1, 1, 1};
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2, 2, 2};
op.padding = {0, 0, 0};
op.stride = {1, 1, 1};
op.dilations = {1, 1, 1};
std::vector<float> data{
-0.179, -1.756, 0.651, 1.955, 1.87, -0.604, 0.247, 0.449, -0.137, 1.187, 1.593,
......@@ -423,13 +527,14 @@ TEST_CASE(lppool_l1_norm_test)
{
// L1 norm test
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::lpnorm};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
op.lp_order = 1;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::lpnorm};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
op.dilations = {1};
op.lp_order = 1;
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
......@@ -449,13 +554,14 @@ TEST_CASE(lppool_l1_norm_test)
// {
// // padding too large for kernel size
// migraphx::program p;
// auto* mm = p.get_main_module();
// auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 5}};
// auto op = migraphx::op::pooling{migraphx::op::pooling_mode::lpnorm};
// op.lengths = {3};
// op.padding = {2};
// op.stride = {1};
// op.lp_order = 1;
// auto* mm = p.get_main_module();
// auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 5}};
// auto op = migraphx::op::pooling{migraphx::op::pooling_mode::lpnorm};
// op.lengths = {3};
// op.padding = {2};
// op.stride = {1};
// op.dilations = {1};
// op.lp_order = 1;
// std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7};
// auto l0 = mm->add_literal(migraphx::literal{s, data});
......@@ -468,13 +574,14 @@ TEST_CASE(lppool_l2_norm_test)
{
// L2 norm test
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::lpnorm};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
op.lp_order = 2;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::lpnorm};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
op.dilations = {1};
op.lp_order = 2;
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
......@@ -506,7 +613,8 @@ TEST_CASE(lppool_dyn_test)
{{"mode", migraphx::op::pooling_mode::lpnorm},
{"lengths", {2}},
{"padding", {0}},
{"stride", {1}}}),
{"stride", {1}},
{"dilations", {1}}}),
x);
p.compile(migraphx::make_target("ref"));
......@@ -571,7 +679,8 @@ TEST_CASE(maxpool_test)
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}},
{"stride", {2, 2}},
{"lengths", {3, 2}}}),
{"lengths", {3, 2}},
{"dilations", {1, 1}}}),
al);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
......@@ -599,7 +708,8 @@ TEST_CASE(maxpool_pad_test)
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {1, 1}},
{"stride", {2, 2}},
{"lengths", {3, 2}}}),
{"lengths", {3, 2}},
{"dilations", {1, 1}}}),
al);
// * * * * * * * *
......@@ -620,12 +730,13 @@ TEST_CASE(maxpool_rank3_test0)
{
// 1D case 1, input is 3D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
op.dilations = {1};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
......@@ -643,12 +754,13 @@ TEST_CASE(maxpool_rank3_test1)
{
// 1D case 2, input is 3D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 5}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2};
op.padding = {0};
op.stride = {2};
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 5}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2};
op.padding = {0};
op.stride = {2};
op.dilations = {1};
std::vector<float> data{0.4975, -0.1226, -0.0405, -0.2861, -0.1227, -0.6186, -0.9618,
0.6022, -0.1912, 1.1925, 0.5493, 0.1692, -0.8039, -1.0281,
......@@ -664,6 +776,55 @@ TEST_CASE(maxpool_rank3_test1)
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(maxpool_rank3_test2)
{
// 1D case 1, input is 3D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
op.dilations = {2};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.4, 0.2, 0.9, 0.5, 0.1, 0.7};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(maxpool_rank3_test4)
{
// 1D case 1, input is 3D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2};
op.padding = {1};
op.stride = {1};
op.dilations = {3};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.4, 0.3, 0.2, 0.9, 0.8, 0.5, 0.1, 0.6, 0.7};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(maxpool_rank3_ceil_test)
{
// 1D case 2, input is 3D, ceil mode
......@@ -674,6 +835,7 @@ TEST_CASE(maxpool_rank3_ceil_test)
op.lengths = {2};
op.padding = {0};
op.stride = {2};
op.dilations = {1};
op.ceil_mode = true;
// clang-format off
......@@ -702,12 +864,13 @@ TEST_CASE(maxpool_rank5_test)
{
// 3D, input is 5D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2, 2, 2};
op.padding = {0, 0, 0};
op.stride = {2, 2, 2};
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2, 2, 2};
op.padding = {0, 0, 0};
op.stride = {2, 2, 2};
op.dilations = {1, 1, 1};
std::vector<float> data{
-2.8029, 0.5861, 0.7015, 0.1297, -1.44, -1.9472, 0.7812, 2.408, -0.3145, 0.3405,
......@@ -741,7 +904,8 @@ TEST_CASE(maxpool_dyn_test)
{{"mode", migraphx::op::pooling_mode::max},
{"lengths", {2}},
{"padding", {0}},
{"stride", {1}}}),
{"stride", {1}},
{"dilations", {1}}}),
x);
p.compile(migraphx::make_target("ref"));
......@@ -755,3 +919,29 @@ TEST_CASE(maxpool_dyn_test)
std::vector<float> gold{0.3, 0.4, 0.4, 0.8, 0.9, 0.9, 0.7, 0.7, 0.6};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(maxpool_dyn_test2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}}};
auto x = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"lengths", {2}},
{"padding", {0}},
{"stride", {1}},
{"dilations", {2}}}),
x);
p.compile(migraphx::make_target("ref"));
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {1, 3, 4}};
migraphx::parameter_map params;
params["X"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.4, 0.2, 0.9, 0.5, 0.1, 0.7};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(scatternd_max_test_1)
{
// r=1, q=2, k=1
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {8}};
migraphx::shape is{itype, {4, 1}};
migraphx::shape us{dtype, {4}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{4, 3, 1, 7};
std::vector<float> upd_vec{9, 3, 1, 12};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_max"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 3, 4, 9, 6, 7, 12};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_max_test_2)
{
// r=2, q=2, k=2
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 2}};
migraphx::shape is{itype, {2, 2}};
migraphx::shape us{dtype, {2}};
std::vector<float> data_vec{1, 2, 3, 4};
std::vector<int64_t> ind_vec{0, 0, 0, 1};
std::vector<float> upd_vec{5, 1};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_max"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5, 2, 3, 4};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_max_test_3)
{
// r=3, q=3, k=3
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 2, 2}};
migraphx::shape is{itype, {2, 1, 3}};
migraphx::shape us{dtype, {2, 1}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{0, 0, 0, 1, 1, 1};
std::vector<float> upd_vec{9, 1};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_max"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{9, 2, 3, 4, 5, 6, 7, 8};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_max_test_4)
{
// r=3, q=2, k=1
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {4, 4, 4}};
migraphx::shape is{itype, {2, 1}};
migraphx::shape us{dtype, {2, 4, 4}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4,
5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{0, 2};
std::vector<float> upd_vec{5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_max"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5, 5, 5, 5, 6, 6, 7, 8, 8, 7, 7, 7, 8, 8, 8, 8, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 2, 3, 3, 3, 4,
5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_max_test_duplicate_idx)
{
// r=3, q=2, k=1
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {4, 4, 4}};
migraphx::shape is{itype, {2, 1}};
migraphx::shape us{dtype, {2, 4, 4}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4,
5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{0, 0};
std::vector<float> upd_vec{5, 5, 5, 5, 2, 2, 2, 2, 7, 7, 7, 7, 4, 4, 4, 4,
1, 1, 1, 1, 6, 6, 6, 6, 3, 3, 3, 3, 8, 8, 8, 8};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_max"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5, 5, 5, 5, 6, 6, 7, 8, 8, 7, 7, 7, 8, 8, 8, 8, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4,
5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(scatternd_min_test_1)
{
// r=1, q=2, k=1
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {8}};
migraphx::shape is{itype, {4, 1}};
migraphx::shape us{dtype, {4}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{4, 3, 1, 7};
std::vector<float> upd_vec{9, 3, 1, 12};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_min"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 1, 3, 3, 5, 6, 7, 8};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_min_test_2)
{
// r=2, q=2, k=2
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 2}};
migraphx::shape is{itype, {2, 2}};
migraphx::shape us{dtype, {2}};
std::vector<float> data_vec{1, 2, 3, 4};
std::vector<int64_t> ind_vec{0, 0, 0, 1};
std::vector<float> upd_vec{5, 1};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_min"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 1, 3, 4};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_min_test_3)
{
// r=3, q=3, k=3
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 2, 2}};
migraphx::shape is{itype, {2, 1, 3}};
migraphx::shape us{dtype, {2, 1}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{0, 0, 0, 1, 1, 1};
std::vector<float> upd_vec{9, 1};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_min"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 3, 4, 5, 6, 7, 1};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_min_test_4)
{
// r=3, q=2, k=1
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {4, 4, 4}};
migraphx::shape is{itype, {2, 1}};
migraphx::shape us{dtype, {2, 4, 4}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4,
5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{0, 2};
std::vector<float> upd_vec{5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_min"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 3, 4, 5, 6, 6, 6, 7, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 3, 3,
4, 4, 4, 4, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_min_test_duplicate_idx)
{
// r=3, q=2, k=1
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {4, 4, 4}};
migraphx::shape is{itype, {2, 1}};
migraphx::shape us{dtype, {2, 4, 4}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4,
5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{0, 0};
std::vector<float> upd_vec{5, 5, 5, 5, 2, 2, 2, 2, 7, 7, 7, 7, 4, 4, 4, 4,
1, 1, 1, 1, 6, 6, 6, 6, 3, 3, 3, 3, 8, 8, 8, 8};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_min"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4,
5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <optional>
#include <test.hpp>
namespace {
migraphx::program
create_program(const migraphx::shape& data_shape, int64_t sorted, std::optional<int64_t> axis)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto data = mm->add_parameter("X", data_shape);
auto op = axis ? migraphx::make_op("unique", {{"axis", *axis}, {"sorted", sorted}})
: migraphx::make_op("unique", {{"sorted", sorted}});
auto r = mm->add_instruction(op, data);
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), r);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), r);
auto r2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), r);
auto r3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), r);
mm->add_return({r0, r1, r2, r3});
return p;
};
template <typename T>
auto run_program(T& data,
const migraphx::shape& data_shape,
int sorted,
std::optional<int64_t> axis = std::nullopt)
{
auto p = create_program(data_shape, sorted, axis);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(data_shape, data.data());
auto rets = p.eval(pp);
std::vector<typename std::remove_reference_t<decltype(data)>::value_type> y;
rets[0].visit([&](auto v) { y.assign(v.begin(), v.end()); });
std::vector<int64_t> y_idx;
rets[1].visit([&](auto v) { y_idx.assign(v.begin(), v.end()); });
std::vector<int64_t> x_rev_idx;
rets[2].visit([&](auto v) { x_rev_idx.assign(v.begin(), v.end()); });
std::vector<int64_t> y_ct;
rets[3].visit([&](auto v) { y_ct.assign(v.begin(), v.end()); });
return std::make_tuple(y, y_idx, x_rev_idx, y_ct);
}
} // namespace
// sorted. single entry
TEST_CASE(unique_sorted_single_entry_test)
{
std::vector<int> data = {2};
int64_t axis = 0;
int64_t sorted = 1;
std::vector<size_t> lens = {1};
migraphx::shape data_shape{migraphx::shape::int32_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted, axis);
std::vector<int> gold_val = {2};
EXPECT(y == gold_val);
std::vector<int64_t> gold_y_idx = {0};
EXPECT(idx == gold_y_idx);
std::vector<int64_t> gold_x_rev = {0};
EXPECT(x_rev == gold_x_rev);
std::vector<int64_t> gold_ct = {1};
EXPECT(ct == gold_ct);
}
// unsorted. single entry
TEST_CASE(unique_unsorted_single_entry_test)
{
std::vector<float> data = {3.33};
int64_t axis = -1;
int64_t sorted = 0;
std::vector<size_t> lens = {1};
migraphx::shape data_shape{migraphx::shape::float_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted, axis);
std::vector<float> gold_val = {3.33};
EXPECT(y == gold_val);
std::vector<int64_t> gold_y_idx = {0};
EXPECT(idx == gold_y_idx);
std::vector<int64_t> gold_x_rev = {0};
EXPECT(x_rev == gold_x_rev);
std::vector<int64_t> gold_ct = {1};
EXPECT(ct == gold_ct);
}
// case 2 sorted. all unique input..
TEST_CASE(unique_sorted_all_unique_test)
{
std::vector<float> data = {2.1, 2.3, 2.4, 2.5, 1.9};
int64_t axis = 0;
int64_t sorted = 1;
std::vector<size_t> lens = {5};
migraphx::shape data_shape{migraphx::shape::float_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted, axis);
std::vector<float> gold_val = {1.9, 2.1, 2.3, 2.4, 2.5};
EXPECT(y == gold_val);
std::vector<int64_t> gold_y_idx = {4, 0, 1, 2, 3};
EXPECT(idx == gold_y_idx);
std::vector<int64_t> gold_x_rev = {1, 2, 3, 4, 0};
EXPECT(x_rev == gold_x_rev);
std::vector<int64_t> gold_ct = {1, 1, 1, 1, 1};
EXPECT(ct == gold_ct);
}
// case 3 unsorted. all unique input
TEST_CASE(unique_unsorted_all_unique_test)
{
std::vector<float> data = {2.1, 2.3, 2.4, 2.5, 1.9};
int64_t axis = 0;
int64_t sorted = 0;
std::vector<size_t> lens = {5};
migraphx::shape data_shape{migraphx::shape::float_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted, axis);
std::vector<float> gold_val = {2.1, 2.3, 2.4, 2.5, 1.9};
EXPECT(y == gold_val);
std::vector<int64_t> gold_y_idx = {0, 1, 2, 3, 4};
EXPECT(idx == gold_y_idx);
std::vector<int64_t> gold_x_rev = {0, 1, 2, 3, 4};
EXPECT(x_rev == gold_x_rev);
std::vector<int64_t> gold_ct = {1, 1, 1, 1, 1};
EXPECT(ct == gold_ct);
}
// case 4 sorted (with dup entries)
TEST_CASE(unique_sorted_dupes_test)
{
std::vector<double> data = {2.1, 2.3, 2.4, 2.5, 1.9, 2.5, 2.3, 2.5};
int64_t axis = 0;
int64_t sorted = 1;
std::vector<size_t> lens = {8};
migraphx::shape data_shape{migraphx::shape::double_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted, axis);
std::vector<double> gold_val = {1.9, 2.1, 2.3, 2.4, 2.5};
EXPECT(y == gold_val);
std::vector<int64_t> gold_ct = {1, 1, 2, 1, 3};
EXPECT(ct == gold_ct);
}
// case 5 unsorted (with dup entries)
TEST_CASE(unique_unsorted_dupes_test)
{
std::vector<float> data = {2.1, 2.3, 2.4, 2.5, 1.9, 2.5, 2.3, 2.1};
int64_t axis = -1;
int64_t sorted = 0;
std::vector<size_t> lens = {8};
migraphx::shape data_shape{migraphx::shape::float_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted, axis);
std::vector<float> gold_val = {2.1, 2.3, 2.4, 2.5, 1.9};
EXPECT(y == gold_val);
std::vector<int64_t> gold_y_idx = {0, 1, 2, 3, 4};
EXPECT(idx == gold_y_idx);
std::vector<int64_t> gold_x_rev = {0, 1, 2, 3, 4, 3, 1, 0};
EXPECT(x_rev == gold_x_rev);
std::vector<int64_t> gold_ct = {2, 2, 1, 2, 1};
EXPECT(ct == gold_ct);
}
TEST_CASE(unique_3D_no_axis_test)
{
// sorted 3D (with dup entries). no axis
int sorted = 1;
std::vector<double> data_3d = {2.1, 2.3, 2.4, 2.5, 1.9, 2.5, 2.3, 2.5};
std::vector<size_t> lens = {2, 2, 2}; // 3D data. type double
migraphx::shape data_shape{migraphx::shape::double_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data_3d, data_shape, sorted);
std::vector<double> gold_val = {1.9, 2.1, 2.3, 2.4, 2.5};
EXPECT(y == gold_val);
std::vector<int64_t> gold_ct = {1, 1, 2, 1, 3};
EXPECT(ct == gold_ct);
}
TEST_CASE(unique_3D_no_axis_unsorted_test)
// unsorted 3D (with dup entries). no axis
{
int sorted = 0;
std::vector<float> data = {2.1, 2.3, 2.4, 2.5, 1.9, 2.5, 2.3, 2.1};
std::vector<size_t> lens = {2, 1, 4}; // 3D data. type float
migraphx::shape data_shape{migraphx::shape::float_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted);
std::vector<float> gold_val = {2.1, 2.3, 2.4, 2.5, 1.9};
EXPECT(y == gold_val);
std::vector<int64_t> gold_y_idx = {0, 1, 2, 3, 4};
EXPECT(idx == gold_y_idx);
std::vector<int64_t> gold_x_rev = {0, 1, 2, 3, 4, 3, 1, 0};
EXPECT(x_rev == gold_x_rev);
std::vector<int64_t> gold_ct = {2, 2, 1, 2, 1};
EXPECT(ct == gold_ct);
}
// unique integer sub-tensors: sorted (with dup entries)
TEST_CASE(unique_subtensors_sorted_test)
{
/*
input_X = [[1, 0, 0], [1, 0, 0], [2, 3, 4]]
attribute_sorted = 1
attribute_axis = 0
output_Y = [[1, 0, 0], [2, 3, 4]]
output_indices = [0, 2]
output_inverse_indices = [0, 0, 1]
output_counts = [2, 1]
*/
int axis = 0;
int sorted = 1;
std::vector<int32_t> data = {1, 0, 0, 1, 0, 0, 2, 3, 4};
std::vector<size_t> lens = {3, 3};
migraphx::shape data_shape{migraphx::shape::int32_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted, axis);
std::vector<int32_t> gold_val = {1, 0, 0, 2, 3, 4};
EXPECT(y == gold_val);
std::vector<int64_t> gold_y_idx = {0, 2};
EXPECT(idx == gold_y_idx);
std::vector<int64_t> gold_x_rev = {0, 0, 1};
EXPECT(x_rev == gold_x_rev);
std::vector<int64_t> gold_ct = {2, 1};
EXPECT(ct == gold_ct);
}
// unique integer sub-tensors: un-sorted (with dup entries)
TEST_CASE(unique_subtensors_neg_axis_test)
{
/*
input_X = [[1, 0, 0], [1, 0, 0], [2, 3, 4]]
attribute_sorted = 0
attribute_axis = 0
output_Y = [[1, 0, 0], [2, 3, 4]]
output_indices = [0, 2]
output_inverse_indices = [0, 0, 1]
output_counts = [2, 1]
*/
int axis = -2; // == 0
int sorted = 0;
std::vector<int32_t> data = {1, 0, 0, 1, 0, 0, 2, 3, 4};
std::vector<size_t> lens = {3, 3};
migraphx::shape data_shape{migraphx::shape::int32_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted, axis);
std::vector<int32_t> gold_val = {1, 0, 0, 2, 3, 4};
EXPECT(y == gold_val);
std::vector<int64_t> gold_y_idx = {0, 2};
EXPECT(idx == gold_y_idx);
std::vector<int64_t> gold_x_rev = {0, 0, 1};
EXPECT(x_rev == gold_x_rev);
std::vector<int64_t> gold_ct = {2, 1};
EXPECT(ct == gold_ct);
}
// unique float sub-tensors: sorted (with dup entries) axis = 0
TEST_CASE(unique_subtensors_zero_axis_test)
{
/*
input_x = [[[1., 1.], [0., 1.], [2., 1.], [0., 1.]],
[[1., 1.], [0., 1.], [2., 1.], [0., 1.]]]
attribute_sorted = 1
attribute_axis = 0
*/
int axis = 0;
int sorted = 1;
std::vector<float> data = {1., 1., 0., 1., 2., 1., 0., 1., 1., 1., 0., 1., 2., 1., 0., 1.};
std::vector<size_t> lens = {2, 4, 2};
migraphx::shape data_shape{migraphx::shape::float_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted, axis);
std::vector<float> gold_val = {1., 1., 0., 1., 2., 1., 0., 1.};
EXPECT(y == gold_val);
std::vector<int64_t> gold_y_idx = {0};
EXPECT(idx == gold_y_idx);
std::vector<int64_t> gold_x_rev = {0, 0};
EXPECT(x_rev == gold_x_rev);
std::vector<int64_t> gold_ct = {2};
EXPECT(ct == gold_ct);
}
......@@ -53,7 +53,8 @@ TEST_CASE(rewrite_pooling_test)
{{"mode", mode},
{"padding", {0, 0, 0}},
{"stride", {1, 1, 1}},
{"lengths", {3, 4, 5}}}),
{"lengths", {3, 4, 5}},
{"dilations", {1, 1, 1}}}),
input);
m.add_return({ret});
return m;
......@@ -80,6 +81,483 @@ TEST_CASE(rewrite_pooling_test)
migraphx::make_op("reduce_max", {{"axes", {2, 3, 4}}}));
}
TEST_CASE(rewrite_pooling_dialtions_test)
{
migraphx::shape s{migraphx::shape::float_type, {1, 1, 5, 5}};
auto pooling_program = [&](const migraphx::op::pooling_mode mode) {
migraphx::module m;
auto input = m.add_parameter("x", s);
auto ret = m.add_instruction(migraphx::make_op("pooling",
{{"mode", mode},
{"padding", {0, 0}},
{"stride", {1, 1}},
{"lengths", {2, 2}},
{"dilations", {2, 2}}}),
input);
m.add_return({ret});
return m;
};
auto opt_program = [&](const migraphx::op::pooling_mode mode) {
migraphx::module m;
auto input = m.add_parameter("x", s);
std::vector<int> indices{0, 2, 1, 3, 2, 4};
migraphx::shape s_indices{migraphx::shape::int32_type, {indices.size()}};
auto i1 = m.add_literal(migraphx::literal{s_indices, indices});
auto g1 = m.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), input, i1);
auto i2 = m.add_literal(migraphx::literal{s_indices, indices});
auto g2 = m.add_instruction(migraphx::make_op("gather", {{"axis", 3}}), g1, i2);
auto ret = m.add_instruction(migraphx::make_op("pooling",
{{"mode", mode},
{"padding", {0, 0}},
{"stride", {2, 2}},
{"lengths", {2, 2}},
{"dilations", {1, 1}}}),
g2);
m.add_return({ret});
return m;
};
auto test_rewrite = [&](const migraphx::op::pooling_mode mode) {
migraphx::module m1 = pooling_program(mode);
migraphx::module m2 = opt_program(mode);
opt_pooling(m1);
EXPECT(m1 == m2);
};
test_rewrite(migraphx::op::pooling_mode::average);
test_rewrite(migraphx::op::pooling_mode::max);
}
TEST_CASE(rewrite_pooling_dialtions_test2)
{
migraphx::shape s{migraphx::shape::float_type, {1, 1, 5, 5, 5}};
auto pooling_program = [&](const migraphx::op::pooling_mode mode) {
migraphx::module m;
auto input = m.add_parameter("x", s);
auto ret = m.add_instruction(migraphx::make_op("pooling",
{{"mode", mode},
{"padding", {0, 0, 0}},
{"stride", {1, 1, 1}},
{"lengths", {2, 2, 2}},
{"dilations", {2, 2, 2}}}),
input);
m.add_return({ret});
return m;
};
auto opt_program = [&](const migraphx::op::pooling_mode mode) {
migraphx::module m;
auto input = m.add_parameter("x", s);
std::vector<int> indices{0, 2, 1, 3, 2, 4};
migraphx::shape s_indices{migraphx::shape::int32_type, {indices.size()}};
auto i1 = m.add_literal(migraphx::literal{s_indices, indices});
auto g1 = m.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), input, i1);
auto i2 = m.add_literal(migraphx::literal{s_indices, indices});
auto g2 = m.add_instruction(migraphx::make_op("gather", {{"axis", 3}}), g1, i2);
auto i3 = m.add_literal(migraphx::literal{s_indices, indices});
auto g3 = m.add_instruction(migraphx::make_op("gather", {{"axis", 4}}), g2, i3);
auto ret = m.add_instruction(migraphx::make_op("pooling",
{{"mode", mode},
{"padding", {0, 0, 0}},
{"stride", {2, 2, 2}},
{"lengths", {2, 2, 2}},
{"dilations", {1, 1, 1}}}),
g3);
m.add_return({ret});
return m;
};
auto test_rewrite = [&](const migraphx::op::pooling_mode mode) {
migraphx::module m1 = pooling_program(mode);
migraphx::module m2 = opt_program(mode);
opt_pooling(m1);
EXPECT(m1 == m2);
};
test_rewrite(migraphx::op::pooling_mode::average);
test_rewrite(migraphx::op::pooling_mode::max);
}
TEST_CASE(rewrite_pooling_dialtions_test3)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 5}};
auto pooling_program = [&]() {
migraphx::module m;
auto input = m.add_parameter("x", s);
auto ret =
m.add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {1}},
{"stride", {1}},
{"lengths", {3}},
{"dilations", {2}}}),
input);
m.add_return({ret});
return m;
};
migraphx::module m1 = pooling_program();
migraphx::module m2 = m1;
opt_pooling(m1);
EXPECT(m1 == m2);
}
TEST_CASE(rewrite_pooling_dialtions_test4)
{
migraphx::shape s{migraphx::shape::float_type, {1, 1, 5, 5}};
auto pooling_program = [&](const migraphx::op::pooling_mode mode) {
migraphx::module m;
auto input = m.add_parameter("x", s);
auto ret = m.add_instruction(migraphx::make_op("pooling",
{{"mode", mode},
{"padding", {1, 0}},
{"stride", {1, 3}},
{"lengths", {3, 1}},
{"dilations", {1, 2}}}),
input);
m.add_return({ret});
return m;
};
auto opt_program = [&](const migraphx::op::pooling_mode mode) {
migraphx::module m;
auto input = m.add_parameter("x", s);
std::vector<int> col_indices{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14};
migraphx::shape s_col_indices{migraphx::shape::int32_type, {col_indices.size()}};
std::vector<int> row_indices{0, 3};
migraphx::shape s_row_indices{migraphx::shape::int32_type, {row_indices.size()}};
auto p =
m.add_instruction(migraphx::make_op("pad",
{{"pads", {0, 0, 1, 0, 0, 0, 1, 0}},
{"value", std::numeric_limits<float>::lowest()}}),
input);
auto i1 = m.add_literal(migraphx::literal{s_col_indices, col_indices});
auto g1 = m.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), p, i1);
auto i2 = m.add_literal(migraphx::literal{s_row_indices, row_indices});
auto g2 = m.add_instruction(migraphx::make_op("gather", {{"axis", 3}}), g1, i2);
auto ret = m.add_instruction(migraphx::make_op("pooling",
{{"mode", mode},
{"padding", {0, 0}},
{"stride", {3, 1}},
{"lengths", {3, 1}},
{"dilations", {1, 1}}}),
g2);
m.add_return({ret});
return m;
};
auto test_rewrite = [&](const migraphx::op::pooling_mode mode) {
migraphx::module m1 = pooling_program(mode);
migraphx::module m2 = opt_program(mode);
opt_pooling(m1);
EXPECT(m1 == m2);
};
// Average won't work because of padding
test_rewrite(migraphx::op::pooling_mode::max);
}
TEST_CASE(rewrite_pooling_dialtions_test5)
{
migraphx::shape s{migraphx::shape::float_type, {1, 1, 5, 5}};
auto pooling_program = [&](const migraphx::op::pooling_mode mode) {
migraphx::module m;
auto input = m.add_parameter("x", s);
auto ret = m.add_instruction(migraphx::make_op("pooling",
{{"mode", mode},
{"padding", {0, 0}},
{"stride", {2, 3}},
{"lengths", {2, 1}},
{"dilations", {1, 2}}}),
input);
m.add_return({ret});
return m;
};
auto opt_program = [&](const migraphx::op::pooling_mode mode) {
migraphx::module m;
auto input = m.add_parameter("x", s);
std::vector<int> col_indices{0, 1, 2, 3};
migraphx::shape s_col_indices{migraphx::shape::int32_type, {col_indices.size()}};
std::vector<int> row_indices{0, 3};
migraphx::shape s_row_indices{migraphx::shape::int32_type, {row_indices.size()}};
auto i1 = m.add_literal(migraphx::literal{s_col_indices, col_indices});
auto g1 = m.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), input, i1);
auto i2 = m.add_literal(migraphx::literal{s_row_indices, row_indices});
auto g2 = m.add_instruction(migraphx::make_op("gather", {{"axis", 3}}), g1, i2);
auto ret = m.add_instruction(migraphx::make_op("pooling",
{{"mode", mode},
{"padding", {0, 0}},
{"stride", {2, 1}},
{"lengths", {2, 1}},
{"dilations", {1, 1}}}),
g2);
m.add_return({ret});
return m;
};
auto test_rewrite = [&](const migraphx::op::pooling_mode mode) {
migraphx::module m1 = pooling_program(mode);
migraphx::module m2 = opt_program(mode);
opt_pooling(m1);
EXPECT(m1 == m2);
};
test_rewrite(migraphx::op::pooling_mode::average);
test_rewrite(migraphx::op::pooling_mode::max);
}
TEST_CASE(rewrite_avgpool_rank3_dil_test)
{
// 1D case 1, input is 3D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
op.dilations = {2};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
opt_pooling(*mm);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.35, 0.15, 0.85, 0.3, 0.1, 0.65};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(rewrite_avgpool_rank3_dil_test2)
{
// 1D case 1, input is 3D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
op.dilations = {3};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
opt_pooling(*mm);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.2, 0.45, 0.35};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(rewrite_avgpool_rank4_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2, 1};
op.padding = {0, 0};
op.stride = {2, 3};
op.dilations = {1, 2};
std::vector<float> data(25);
std::iota(data.begin(), data.end(), 1);
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
opt_pooling(*mm);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{3.5, 6.5, 13.5, 16.5};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(rewrite_maxpool_rank3_test)
{
// 1D case 1, input is 3D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
op.dilations = {2};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
opt_pooling(*mm);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.4, 0.2, 0.9, 0.5, 0.1, 0.7};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(rewrite_maxpool_rank3_test2)
{
// 1D case 1, input is 3D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2};
op.padding = {1};
op.stride = {1};
op.dilations = {3};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
opt_pooling(*mm);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.4, 0.3, 0.2, 0.9, 0.8, 0.5, 0.1, 0.6, 0.7};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(rewrite_maxpool_rank3_test3)
{
// 1D case 1, input is 3D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {3};
op.padding = {2};
op.stride = {2};
op.dilations = {3};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
opt_pooling(*mm);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.2, 0.5, 0.7};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(rewrite_maxpool_rank4_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {3, 1};
op.padding = {1, 0};
op.stride = {1, 3};
op.dilations = {1, 2};
std::vector<float> data(25);
std::iota(data.begin(), data.end(), 1);
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
opt_pooling(*mm);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{6, 9, 11, 14, 16, 19, 21, 24, 21, 24};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(maxpool_rank5_test)
{
// 3D, input is 5D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2, 2, 2};
op.padding = {0, 0, 0};
op.stride = {1, 1, 1};
op.dilations = {2, 2, 2};
std::vector<float> data{
-2.8029, 0.5861, 0.7015, 0.1297, -1.44, -1.9472, 0.7812, 2.408, -0.3145, 0.3405,
-0.9146, 0.0624, 1.5064, -0.8345, 1.7977, 1.8949, 1.0073, -0.2102, -0.042, -0.7146,
0.6227, -0.5263, -2.2598, 0.1713, 0.449, 0.5303, -0.8622, -0.5691, 0.907, -0.0569,
-1.5348, -0.4109, -0.1461, -0.5445, 0.4266, 0.2282, 1.3655, -2.1519, 0.6068, -0.2001,
-0.4702, 0.3864, 1.7083, 0.9096, 0.4286, -1.8866, 0.7034, 0.0293, 1.4587, 0.7672,
-2.8614, 0.8124, -0.053, 1.0449, 0.845, -0.0131, 0.1139, -0.859, -1.2681, -0.6337,
-0.4644, 0.1938, 0.2889, 0.9035, 0.7118, -0.5767, 0.4577, -0.0549, 0.2237, 0.5756,
0.0677, -0.0223, -0.329, 0.2364, 2.7666, -0.7417, -1.3196, -0.2655, 0.1698, -0.1777,
-0.9427, 2.6859, -0.7501, 0.5175, 1.0029, -2.6436, -0.4388, -1.2348, -0.1539, -0.6229,
-0.4136, 0.5085, 0.4136, -0.6439, -1.1953, -0.406, -0.0195, 0.1869, -0.8664, 1.1364,
0.5041, 0.0647, 0.1941, -1.0819, -0.4629, -0.5107, 0.3612, -0.3583};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
opt_pooling(*mm);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.7812, 1.0449, 2.7666, 2.6859};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(maxpool_rank5_test2)
{
// 3D, input is 5D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2, 2, 2};
op.padding = {2, 2, 2};
op.stride = {2, 2, 2};
op.dilations = {3, 3, 3};
std::vector<float> data{
-2.8029, 0.5861, 0.7015, 0.1297, -1.44, -1.9472, 0.7812, 2.408, -0.3145, 0.3405,
-0.9146, 0.0624, 1.5064, -0.8345, 1.7977, 1.8949, 1.0073, -0.2102, -0.042, -0.7146,
0.6227, -0.5263, -2.2598, 0.1713, 0.449, 0.5303, -0.8622, -0.5691, 0.907, -0.0569,
-1.5348, -0.4109, -0.1461, -0.5445, 0.4266, 0.2282, 1.3655, -2.1519, 0.6068, -0.2001,
-0.4702, 0.3864, 1.7083, 0.9096, 0.4286, -1.8866, 0.7034, 0.0293, 1.4587, 0.7672,
-2.8614, 0.8124, -0.053, 1.0449, 0.845, -0.0131, 0.1139, -0.859, -1.2681, -0.6337,
-0.4644, 0.1938, 0.2889, 0.9035, 0.7118, -0.5767, 0.4577, -0.0549, 0.2237, 0.5756,
0.0677, -0.0223, -0.329, 0.2364, 2.7666, -0.7417, -1.3196, -0.2655, 0.1698, -0.1777,
-0.9427, 2.6859, -0.7501, 0.5175, 1.0029, -2.6436, -0.4388, -1.2348, -0.1539, -0.6229,
-0.4136, 0.5085, 0.4136, -0.6439, -1.1953, -0.406, -0.0195, 0.1869, -0.8664, 1.1364,
0.5041, 0.0647, 0.1941, -1.0819, -0.4629, -0.5107, 0.3612, -0.3583};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
opt_pooling(*mm);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{-0.8345, 1.5064, -0.9146, 0.3405, -1.44, 0.1297, 0.5861, -2.8029,
-0.4702, -0.2001, -2.1519, 1.3655, -0.4109, -1.5348, 0.907, -0.5691,
-0.0549, 0.4577, 0.7118, 0.9035, -1.2681, -0.859, -0.0131, 0.845,
-1.1953, -0.6439, 0.5085, -0.4136, -2.6436, 1.0029, -0.7501, 2.6859};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(rewrite_avepooling_na1_test)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};
......@@ -92,7 +570,8 @@ TEST_CASE(rewrite_avepooling_na1_test)
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 1, 0}},
{"stride", {1, 1, 1}},
{"lengths", {3, 4, 5}}}),
{"lengths", {3, 4, 5}},
{"dilations", {1, 1, 1}}}),
input);
m.add_return({ret});
return m;
......@@ -117,7 +596,8 @@ TEST_CASE(rewrite_avepooling_na2_test)
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0}},
{"stride", {1, 2, 1}},
{"lengths", {3, 4, 5}}}),
{"lengths", {3, 4, 5}},
{"dilations", {1, 1, 1}}}),
input);
m.add_return({ret});
return m;
......@@ -141,7 +621,8 @@ TEST_CASE(rewrite_avepooling_na3_test)
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0, 0}},
{"stride", {1, 1, 1}},
{"lengths", {3, 3, 5}}}),
{"lengths", {3, 3, 5}},
{"dilations", {1, 1, 1}}}),
input);
m.add_return({ret});
return m;
......@@ -169,7 +650,8 @@ TEST_CASE(literal_rewrite_pooling_test)
{{"mode", mode},
{"padding", {0, 0, 0}},
{"stride", {1, 1, 1}},
{"lengths", {3, 4, 5}}}),
{"lengths", {3, 4, 5}},
{"dilations", {1, 1, 1}}}),
input);
mm->add_return({ret});
return p;
......
......@@ -1017,6 +1017,40 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
EXPECT(m1 == m2);
}
TEST_CASE(concat_convert_fusion)
{
auto s = migraphx::shape{migraphx::shape::float_type, {64}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s);
auto xh = m1.add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
x);
auto yh = m1.add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
y);
auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), xh, yh);
m1.add_instruction(pass_op{}, concat);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s);
auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y);
auto concath = m2.add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
concat);
m2.add_instruction(pass_op{}, concath);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_div_const)
{
migraphx::module m1;
......
......@@ -155,29 +155,187 @@ TEST_CASE(after_split_dyn_broadcast_match)
EXPECT(p0 == p1);
}
TEST_CASE(const_slice_3input)
TEST_CASE(const_slice_2input_ends_axes)
{
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m0.add_parameter("data", s);
auto slice_ins = m0.add_instruction(
auto input = m0.add_parameter("data", s);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto input_starts = m0.add_literal(migraphx::literal{s1, {0}});
auto slice_ins = m0.add_instruction(
migraphx::make_op("slice", {{"ends", {3}}, {"axes", {0}}}), input, input_starts);
m0.add_return({slice_ins});
}
run_pass(m0);
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m1.add_parameter("data", s);
auto slice_ins = m1.add_instruction(
migraphx::make_op("slice", {{"starts", {0}}, {"ends", {3}}, {"axes", {0}}}), input);
m1.add_return({slice_ins});
}
EXPECT(m0 == m1);
}
TEST_CASE(const_slice_2input_starts_axes)
{
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m0.add_parameter("data", s);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto input_ends = m0.add_literal(migraphx::literal{s1, {3}});
auto slice_ins = m0.add_instruction(
migraphx::make_op("slice", {{"starts", {0}}, {"axes", {0}}}), input, input_ends);
m0.add_return({slice_ins});
}
run_pass(m0);
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m1.add_parameter("data", s);
auto input = m1.add_parameter("data", s);
auto slice_ins = m1.add_instruction(
migraphx::make_op("slice", {{"starts", {0}}, {"ends", {3}}, {"axes", {0}}}), input);
m1.add_return({slice_ins});
}
EXPECT(m0 == m1);
}
TEST_CASE(const_slice_2input_starts_ends)
{
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m0.add_parameter("data", s);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto input_starts = m1.add_literal(migraphx::literal{s1, {0}});
auto input_ends = m1.add_literal(migraphx::literal{s1, {3}});
auto slice_ins = m1.add_instruction(
auto input_axes = m0.add_literal(migraphx::literal{s1, {0}});
auto slice_ins = m0.add_instruction(
migraphx::make_op("slice", {{"starts", {0}}, {"ends", {3}}}), input, input_axes);
m0.add_return({slice_ins});
}
run_pass(m0);
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m1.add_parameter("data", s);
auto slice_ins = m1.add_instruction(
migraphx::make_op("slice", {{"starts", {0}}, {"ends", {3}}, {"axes", {0}}}), input);
m1.add_return({slice_ins});
}
EXPECT(m0 == m1);
}
TEST_CASE(const_slice_3input_axes_only)
{
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m0.add_parameter("data", s);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto input_starts = m0.add_literal(migraphx::literal{s1, {0}});
auto input_ends = m0.add_literal(migraphx::literal{s1, {3}});
auto slice_ins = m0.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}}), input, input_starts, input_ends);
m0.add_return({slice_ins});
}
run_pass(m0);
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m1.add_parameter("data", s);
auto slice_ins = m1.add_instruction(
migraphx::make_op("slice", {{"starts", {0}}, {"ends", {3}}, {"axes", {0}}}), input);
m1.add_return({slice_ins});
}
EXPECT(m0 == m1);
}
TEST_CASE(const_slice_3input_ends_only)
{
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m0.add_parameter("data", s);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto input_starts = m0.add_literal(migraphx::literal{s1, {0}});
auto input_axes = m0.add_literal(migraphx::literal{s1, {0}});
auto slice_ins = m0.add_instruction(
migraphx::make_op("slice", {{"ends", {3}}}), input, input_starts, input_axes);
m0.add_return({slice_ins});
}
run_pass(m0);
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m1.add_parameter("data", s);
auto slice_ins = m1.add_instruction(
migraphx::make_op("slice", {{"starts", {0}}, {"ends", {3}}, {"axes", {0}}}), input);
m1.add_return({slice_ins});
}
EXPECT(m0 == m1);
}
TEST_CASE(const_slice_3inputs_starts_only)
{
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m0.add_parameter("data", s);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto input_ends = m0.add_literal(migraphx::literal{s1, {3}});
auto input_axes = m0.add_literal(migraphx::literal{s1, {0}});
auto slice_ins = m0.add_instruction(
migraphx::make_op("slice", {{"starts", {0}}}), input, input_ends, input_axes);
m0.add_return({slice_ins});
}
run_pass(m0);
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m1.add_parameter("data", s);
auto slice_ins = m1.add_instruction(
migraphx::make_op("slice", {{"starts", {0}}, {"ends", {3}}, {"axes", {0}}}), input);
m1.add_return({slice_ins});
}
EXPECT(m0 == m1);
}
TEST_CASE(const_slice_2input_ends_axes_dyn)
{
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {{6, 6}, {2, 4, {2, 4}}, {2, 4, {2, 4}}}};
auto input = m0.add_parameter("data", s);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto input_starts = m0.add_literal(migraphx::literal{s1, {0}});
auto slice_ins = m0.add_instruction(
migraphx::make_op("slice", {{"ends", {3}}, {"axes", {0}}}), input, input_starts);
m0.add_return({slice_ins});
}
run_pass(m0);
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {{6, 6}, {2, 4, {2, 4}}, {2, 4, {2, 4}}}};
auto input = m1.add_parameter("data", s);
auto slice_ins = m1.add_instruction(
migraphx::make_op("slice", {{"starts", {0}}, {"ends", {3}}, {"axes", {0}}}), input);
m1.add_return({slice_ins});
}
run_pass(m1);
EXPECT(m0 == m1);
}
......@@ -319,4 +477,98 @@ TEST_CASE(static_dimensions_of_nonfixed)
EXPECT(m0 == m1);
}
TEST_CASE(constant_alloc_reshape)
{
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {3, 32}};
auto input = m0.add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape::int64_type, {3}};
auto literal_ins = m0.add_literal(migraphx::literal{lit_s, {3, 4, 8}});
auto alloc_ins = m0.add_instruction(
migraphx::make_op("allocate", {{"buf_type", migraphx::shape::float_type}}),
literal_ins);
auto reshape_ins = m0.add_instruction(migraphx::make_op("reshape"), input, alloc_ins);
m0.add_return({reshape_ins});
}
run_pass(m0);
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {3, 32}};
auto input = m1.add_parameter("data", s);
auto reshape_ins =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 8}}}), input);
m1.add_return({reshape_ins});
}
EXPECT(m0 == m1);
}
// A more contrived example to test static dimensions_of and constant reshape
TEST_CASE(static_dimensions_of_to_constant_alloc_reshape)
{
migraphx::module m0;
{
migraphx::shape input_shape{migraphx::shape::float_type, {3, 4, 8}};
auto x_param = m0.add_parameter("x", input_shape);
auto dimensions_of_ins =
m0.add_instruction(migraphx::make_op("dimensions_of", {{"end", 3}}), x_param);
migraphx::shape lit_shape{migraphx::shape::int64_type, {1}};
auto lit0 = m0.add_literal(migraphx::literal{lit_shape, {0}});
auto gather_ins =
m0.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), dimensions_of_ins, lit0);
auto slice_ins = m0.add_instruction(
migraphx::make_op("slice", {{"starts", {1}}, {"ends", {3}}, {"axes", {0}}}),
dimensions_of_ins);
auto reduce_ins =
m0.add_instruction(migraphx::make_op("reduce_prod", {{"axes", {0}}}), slice_ins);
auto concat_ins =
m0.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), gather_ins, reduce_ins);
auto alloc_ins = m0.add_instruction(
migraphx::make_op("allocate", {{"buf_type", migraphx::shape::float_type}}), concat_ins);
auto reshape_ins = m0.add_instruction(migraphx::make_op("reshape"), x_param, alloc_ins);
m0.add_return({reshape_ins});
}
run_pass(m0);
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {3, 4, 8}};
auto x_param = m1.add_parameter("x", s);
auto reshape_ins =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 32}}}), x_param);
m1.add_return({reshape_ins});
}
EXPECT(m0 == m1);
}
TEST_CASE(const_alloc_fill)
{
migraphx::module m0;
{
migraphx::shape val_shape{migraphx::shape::int64_type, {1}, {0}};
std::vector<int64_t> lit_data = {3};
auto value_lit = m0.add_literal(migraphx::literal{val_shape, lit_data});
migraphx::shape lit_s{migraphx::shape::int64_type, {3}};
auto output_dim_lit = m0.add_literal(migraphx::literal{lit_s, {3, 4, 4}});
auto alloc_ins = m0.add_instruction(
migraphx::make_op("allocate", {{"buf_type", migraphx::shape::int64_type}}),
output_dim_lit);
auto ret = m0.add_instruction(migraphx::make_op("fill"), value_lit, alloc_ins);
m0.add_return({ret});
}
run_pass(m0);
migraphx::module m1;
{
migraphx::shape lit_shape{migraphx::shape::int64_type, {3, 4, 4}};
std::vector<int64_t> lit_data(3 * 4 * 4, 3);
auto ret = m1.add_literal(migraphx::literal{lit_shape, lit_data});
m1.add_return({ret});
}
EXPECT(m0 == m1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -788,6 +788,7 @@ TEST_CASE(conv_pooling_dot)
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {7, 7}},
{"dilations", {1, 1}},
{"ceil_mode", 0}}),
a1);
auto fl = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
......@@ -835,6 +836,7 @@ TEST_CASE(conv_pooling_dot)
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {7, 7}},
{"dilations", {1, 1}},
{"ceil_mode", 0}}),
a1);
auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
......@@ -896,6 +898,7 @@ TEST_CASE(mobilenet_snippet)
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {7, 7}},
{"dilations", {1, 1}},
{"ceil_mode", 0}}),
d6);
auto q3 = add_quantize_op(mm, "quantizelinear", ap, scale, zero);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment