Unverified Commit 70d9faf7 authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into mi200

parents a56c531c a60bdb67
 split_test_invalid_num_outputs:
.
xy1y2y3y4"Split*
num_outputssplit_test_invalid_num_outputsZ
x


b
y1


b
y2


b
y3


b
y4


B
\ No newline at end of file
 split_test_uneven_num_outputs:
.
xy1y2y3y4"Split*
num_outputssplit_test_uneven_num_outputsZ
x


b
y1


b
y2


b
y3


b
y4


B
\ No newline at end of file
 unique_dynamic_sorted_3D_test:Ö
?
XYindicesinverse_indicescounts"Unique*
sorted unique_dynamic_sorted_3D_testZ
X



b
Y

b
indices

b
inverse_indices

@b
counts

B
\ No newline at end of file
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -351,6 +351,87 @@ TEST_CASE(depthtospace_simple_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(dynamicquantizelinear_1d_test)
{
    // DynamicQuantizeLinear over a 1-D input: checks the quantized output, the
    // dynamically computed scale, and the computed zero point.
    auto prog = migraphx::parse_onnx("dynamicquantizelinear_1d_test.onnx");
    prog.compile(migraphx::make_target("ref"));

    std::vector<float> input{0, 2, -3, -2.5, 1.34, 0.5};
    migraphx::shape input_shape{migraphx::shape::float_type, {6}};
    migraphx::parameter_map params;
    params["x"] = migraphx::argument(input_shape, input.data());

    auto outputs = prog.eval(params);

    // Output 0: quantized y values.
    std::vector<uint8_t> y;
    outputs.at(0).visit([&](auto out) { y.assign(out.begin(), out.end()); });
    std::vector<uint8_t> expected_y = {153, 255, 0, 26, 221, 179};
    EXPECT(migraphx::verify::verify_rms_range(y, expected_y));

    // Output 1: computed scale.
    std::vector<float> scale;
    outputs.at(1).visit([&](auto out) { scale.assign(out.begin(), out.end()); });
    std::vector<float> expected_scale = {0.0196078438};
    EXPECT(migraphx::verify::verify_rms_range(scale, expected_scale));

    // Output 2: computed zero point.
    std::vector<uint8_t> zero_point;
    outputs.at(2).visit([&](auto out) { zero_point.assign(out.begin(), out.end()); });
    std::vector<uint8_t> expected_zp = {153};
    EXPECT(migraphx::verify::verify_rms_range(zero_point, expected_zp));
}
// DynamicQuantizeLinear on an all-negative 1-D input: the quantization range is
// widened to include 0, which pushes the zero point to the top of the uint8 range.
TEST_CASE(dynamicquantizelinear_1d_max_adjusted_test)
{
auto p = migraphx::parse_onnx("dynamicquantizelinear_1d_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<float> data{-1.0, -2.1, -1.3, -2.5, -3.34, -4.0};
migraphx::shape s_x{migraphx::shape::float_type, {6}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data.data());
auto results = p.eval(pp);
// Output 0: quantized values.
std::vector<uint8_t> y_results;
results.at(0).visit([&](auto output) { y_results.assign(output.begin(), output.end()); });
std::vector<uint8_t> y_gold = {191, 121, 172, 96, 42, 0};
EXPECT(migraphx::verify::verify_rms_range(y_results, y_gold));
// Output 1: computed scale (4/255, from the adjusted [-4, 0] range).
std::vector<float> y_scale;
results.at(1).visit([&](auto output) { y_scale.assign(output.begin(), output.end()); });
std::vector<float> y_scale_gold = {0.0156862754};
EXPECT(migraphx::verify::verify_rms_range(y_scale, y_scale_gold));
// Output 2: zero point saturates to 255 because every input is <= 0.
std::vector<uint8_t> y_zpt;
results.at(2).visit([&](auto output) { y_zpt.assign(output.begin(), output.end()); });
std::vector<uint8_t> y_zpt_gold = {255};
EXPECT(migraphx::verify::verify_rms_range(y_zpt, y_zpt_gold));
}
// DynamicQuantizeLinear on a 2-D (3x4) all-positive input: scale and zero point
// are computed over the whole tensor, so there is a single scale/zp output.
TEST_CASE(dynamicquantizelinear_2d_test)
{
auto p = migraphx::parse_onnx("dynamicquantizelinear_2d_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<float> data{1.0, 2.1, 1.3, 2.5, 3.34, 4.0, 1.5, 2.6, 3.9, 4.0, 3.0, 2.345};
migraphx::shape s_x{migraphx::shape::float_type, {3, 4}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data.data());
auto results = p.eval(pp);
// Output 0: quantized values.
std::vector<uint8_t> y_results;
results.at(0).visit([&](auto output) { y_results.assign(output.begin(), output.end()); });
std::vector<uint8_t> y_gold = {64, 134, 83, 159, 213, 255, 96, 166, 249, 255, 191, 149};
EXPECT(migraphx::verify::verify_rms_range(y_results, y_gold));
// Output 1: computed scale.
std::vector<float> y_scale;
results.at(1).visit([&](auto output) { y_scale.assign(output.begin(), output.end()); });
std::vector<float> y_scale_gold = {0.0156862754};
EXPECT(migraphx::verify::verify_rms_range(y_scale, y_scale_gold));
// Output 2: zero point is 0 because all inputs are >= 0.
std::vector<uint8_t> y_zpt;
results.at(2).visit([&](auto output) { y_zpt.assign(output.begin(), output.end()); });
std::vector<uint8_t> y_zpt_gold = {0};
EXPECT(migraphx::verify::verify_rms_range(y_zpt, y_zpt_gold));
}
TEST_CASE(spacetodepth_simple_test)
{
auto p = migraphx::parse_onnx("spacetodepth_simple_test.onnx");
......@@ -1014,6 +1095,95 @@ TEST_CASE(instance_norm_3d_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// IsInf on half-precision input: both +inf and -inf are flagged (gold = 1);
// NaN, min, max, and finite values are not.
TEST_CASE(isinf_half_test)
{
migraphx::program p = migraphx::parse_onnx("isinf_half_test.onnx");
p.compile(migraphx::make_target("ref"));
migraphx::shape s{migraphx::shape::half_type, {2, 3}};
migraphx::parameter_map pp;
// Probe the full spectrum of special half values.
migraphx::half nan = std::numeric_limits<migraphx::half>::quiet_NaN();
migraphx::half infinity = std::numeric_limits<migraphx::half>::infinity();
migraphx::half max = std::numeric_limits<migraphx::half>::max();
migraphx::half min = std::numeric_limits<migraphx::half>::min();
migraphx::half val = migraphx::half(3.6);
std::vector<migraphx::half> data = {-infinity, nan, min, val, max, infinity};
pp["t1"] = migraphx::argument(s, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1, 0, 0, 0, 0, 1};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// IsInf with only negative-infinity detection enabled (per the test model name):
// -inf is flagged, +inf is not.
TEST_CASE(isinf_neg_test)
{
migraphx::program p = migraphx::parse_onnx("isinf_neg_test.onnx");
p.compile(migraphx::make_target("ref"));
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::parameter_map pp;
float nan = std::numeric_limits<float>::quiet_NaN();
float infinity = std::numeric_limits<float>::infinity();
float max = std::numeric_limits<float>::max();
float min = std::numeric_limits<float>::min();
std::vector<float> data = {-infinity, nan, min, 3.6, max, infinity};
pp["t1"] = migraphx::argument(s, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// Only the leading -inf is detected.
std::vector<float> gold = {1, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// IsInf on double input with only positive-infinity detection enabled (per the
// test model name): +inf is flagged, -inf is not.
TEST_CASE(isinf_double_pos_test)
{
migraphx::program p = migraphx::parse_onnx("isinf_double_pos_test.onnx");
p.compile(migraphx::make_target("ref"));
migraphx::shape s{migraphx::shape::double_type, {2, 3}};
migraphx::parameter_map pp;
double nan = std::numeric_limits<double>::quiet_NaN();
double infinity = std::numeric_limits<double>::infinity();
double max = std::numeric_limits<double>::max();
double min = std::numeric_limits<double>::min();
std::vector<double> data = {-infinity, nan, min, 3.6, max, infinity};
pp["t1"] = migraphx::argument(s, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// Only the trailing +inf is detected.
std::vector<float> gold = {0, 0, 0, 0, 0, 1};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(isinf_no_detect_test)
{
    // IsInf with both detection directions disabled (per the test model name):
    // nothing is flagged, so the expected output is all zeros.
    migraphx::program p = migraphx::parse_onnx("isinf_no_detect_test.onnx");
    p.compile(migraphx::make_target("ref"));
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    migraphx::parameter_map pp;
    float nan      = std::numeric_limits<float>::quiet_NaN();
    float infinity = std::numeric_limits<float>::infinity();
    float max      = std::numeric_limits<float>::max();
    float min      = std::numeric_limits<float>::min();
    // Bug fix: the parameter shape is float_type, so the backing buffer must be
    // float. The previous std::vector<double> would be reinterpreted as floats
    // by migraphx::argument, feeding garbage values into the test.
    std::vector<float> data = {-infinity, nan, min, 3.6, max, infinity};
    pp["t1"] = migraphx::argument(s, data.data());
    auto result = p.eval(pp).back();
    std::vector<float> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {0, 0, 0, 0, 0, 0};
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(layer_norm_test)
{
std::vector<float> scale{1.2, 0.8};
......@@ -1434,6 +1604,77 @@ TEST_CASE(mod_test_fmod_different_types)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// Multinomial with a dynamic batch dimension: samples 100000 draws per batch row
// from two 5-category log-probability distributions and checks that the empirical
// per-category frequencies match the input distributions within 1% RMS tolerance.
TEST_CASE(multinomial_dyn_test)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4};
auto p = migraphx::parse_onnx("multinomial_dyn_test.onnx", options);
const size_t batch_size(2);
const size_t categories(5);
const size_t sample_size(100000);
p.compile(migraphx::make_target("ref"));
// Distribution function (2 distributions of 5 categories each)
std::vector<int> dist{15, 25, 15, 25, 20, 20, 20, 10, 25, 25};
EXPECT(dist.size() == categories * batch_size);
// Multinomial takes unnormalized log-probabilities as input.
std::vector<float> data(categories * batch_size);
std::transform(dist.begin(), dist.end(), data.begin(), [&](auto d) { return log(d); });
// Shape of the probability distribution, which also defines the number of categories
migraphx::shape s{migraphx::shape::float_type, {batch_size, categories}};
migraphx::parameter_map pp;
pp["input"] = migraphx::argument(s, data.data());
auto result = p.eval(pp).back();
std::vector<int32_t> result_vec(batch_size * sample_size);
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
// Make a categorical histogram of output
// for first result in batch
std::vector<int> res_dist(categories, 0);
size_t r = 0;
// First half of result_vec belongs to batch row 0.
for(r = 0; r < result_vec.size() / 2; r++)
res_dist[result_vec[r]]++;
// normalizing factors for original and measured distributions
auto dist_sum = std::accumulate(dist.begin(), dist.begin() + 5, 0);
auto res_dist_sum = std::accumulate(res_dist.begin(), res_dist.end(), 0);
// Values approximate the distribution in dist
std::vector<float> norm(5);
std::vector<float> res_norm(5);
std::transform(dist.begin(), dist.begin() + 5, norm.begin(), [&](auto n) {
return static_cast<double>(n) / dist_sum;
});
std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
return static_cast<double>(n) / res_dist_sum;
});
EXPECT(migraphx::verify::verify_range_with_tolerance(
norm, migraphx::verify::expected{res_norm}, migraphx::verify::tolerance{0.01}));
// Make a categorical histogram of output
// for second result in batch
std::fill(res_dist.begin(), res_dist.end(), 0);
// r continues from the midpoint: second half belongs to batch row 1.
for(; r < result_vec.size(); r++)
res_dist[result_vec[r]]++;
dist_sum = std::accumulate(dist.begin() + 5, dist.end(), 0);
res_dist_sum = std::accumulate(res_dist.begin(), res_dist.end(), 0);
std::transform(dist.begin() + 5, dist.end(), norm.begin(), [&](auto n) {
return static_cast<double>(n) / dist_sum;
});
std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
return static_cast<double>(n) / res_dist_sum;
});
EXPECT(migraphx::verify::verify_range_with_tolerance(
res_norm, migraphx::verify::expected{norm}, migraphx::verify::tolerance{0.01}));
}
TEST_CASE(nonzero_test)
{
migraphx::program p = migraphx::parse_onnx("nonzero_dynamic_test.onnx");
......@@ -1526,6 +1767,298 @@ TEST_CASE(qlinearadd_bcast_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// QLinearAveragePool over a 1-D spatial axis (int8, shape {1, 3, 32}): gold
// values were generated against the onnxruntime reference implementation.
TEST_CASE(qlinearaveragepool_1d_test)
{
auto p = migraphx::parse_onnx("qlinearaveragepool_1d_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<int8_t> data_x = {
-31, 51, 125, 30, -17, -125, 121, -19, -13, 52, 18, -70, 97, 15, 56, 42,
-65, -26, 40, -109, -70, 83, 110, -94, 34, 70, 5, -23, -60, -68, 19, 48,
-113, 3, -44, 20, -99, -103, -49, -38, 122, 75, 38, -7, -65, -56, 96, 99,
50, -27, -114, 49, -65, 105, -3, 54, 8, 38, -81, -46, -86, -46, -104, 36,
22, -51, 48, 59, -116, 6, 93, 16, -111, 98, 51, -87, -111, -74, -39, 7,
107, 115, 59, 60, -66, -14, -106, -23, 119, -122, -51, -100, 26, 125, 45, 90};
migraphx::shape s_x{migraphx::shape::int8_type, {1, 3, 32}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data_x.data());
auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// Requantized averages; values at the range edges show int8 saturation.
std::vector<int8_t> gold = {
26, 104, 94, 22, -55, 14, 67, 0, 36, 51, -10, 29, 72, 52, 65, 5,
-30, 23, -19, -74, 23, 112, 24, -14, 68, 54, 7, -26, -48, -8, 50, -39,
-4, 4, -24, -85, -60, -28, 58, 114, 72, 31, -20, -44, 36, 114, 90, 28,
-54, -16, 8, 36, 67, 42, 47, 39, -6, -48, -50, -50, -59, -18, 2, 15,
70, -13, -39, 66, 71, -32, 9, 90, -2, -83, -76, -40, 0, 73, 127, 103,
75, 13, -24, -44, -48, 64, 15, -70, -60, -21, 92, 101, 84};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// QLinearAveragePool over 2-D spatial dims (int8, shape {1, 3, 4, 4}); the
// requantized output saturates frequently at +/-127/-128.
TEST_CASE(qlinearaveragepool_2d_test)
{
auto p = migraphx::parse_onnx("qlinearaveragepool_2d_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<int8_t> data_x = {84, -73, 117, -2, -97, 72, 67, 27, 1, -44, 110, 51,
9, 7, 58, 113, -34, 34, 124, -20, 6, 66, 68, 98,
31, -84, 25, 101, -69, -100, -68, 116, 33, -121, 78, 49,
102, -86, 65, 69, -87, -89, 16, -125, 51, -54, -86, 79};
migraphx::shape s_x{migraphx::shape::int8_type, {1, 3, 4, 4}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data_x.data());
auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold = {4, 127, 127, -41, 127, 127, -6, 125, 127,
76, 127, 127, 32, 78, 127, -128, -128, 127,
-44, -37, 127, -117, -62, 37, -128, -128, -81};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearaveragepool_2d_ceil_test)
{
    // QLinearAveragePool (uint8) with ceil-mode output rounding; the final gold
    // value of 255 shows uint8 saturation at the top of the range.
    auto prog = migraphx::parse_onnx("qlinearaveragepool_2d_ceil_test.onnx");
    prog.compile(migraphx::make_target("ref"));

    std::vector<uint8_t> input = {2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32};
    migraphx::shape input_shape{migraphx::shape::uint8_type, {1, 1, 4, 4}};
    migraphx::parameter_map params;
    params["x"] = migraphx::argument(input_shape, input.data());

    auto result = prog.eval(params).back();
    std::vector<uint8_t> out;
    result.visit([&](auto output) { out.assign(output.begin(), output.end()); });

    std::vector<uint8_t> gold = {120, 150, 240, 255};
    EXPECT(migraphx::verify::verify_rms_range(out, gold));
}
TEST_CASE(qlinearaveragepool_2d_dilations_test)
{
    // QLinearAveragePool (int8) with dilated pooling windows; the final gold
    // value of 127 shows int8 saturation.
    auto prog = migraphx::parse_onnx("qlinearaveragepool_2d_dilations_test.onnx");
    prog.compile(migraphx::make_target("ref"));

    std::vector<int8_t> input = {2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32};
    migraphx::shape input_shape{migraphx::shape::int8_type, {1, 1, 4, 4}};
    migraphx::parameter_map params;
    params["x"] = migraphx::argument(input_shape, input.data());

    auto result = prog.eval(params).back();
    std::vector<int8_t> out;
    result.visit([&](auto output) { out.assign(output.begin(), output.end()); });

    std::vector<int8_t> gold = {108, 112, 124, 127};
    EXPECT(migraphx::verify::verify_rms_range(out, gold));
}
// QLinearAveragePool (int8) with explicit pads and count_include_pad=1: padded
// zeros are counted in the divisor, which lowers the averages near the borders.
TEST_CASE(qlinearaveragepool_2d_pads_count_include_pad_test)
{
auto p = migraphx::parse_onnx("qlinearaveragepool_2d_pads_count_include_pad_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<int8_t> data_x = {-30, 50, 91, -87, -21, -113, -16, 6, -128, 104, 82, -126,
54, 41, -71, 62, -11, -111, 13, 104, -43, -48, 30, 85,
-62, -33, -27, -114, 32, -17, 30, -26, -18, 15, 17, 100,
-122, 115, 84, -34, -86, 82, 102, -117, -91, -105, 112, 91};
migraphx::shape s_x{migraphx::shape::int8_type, {1, 3, 4, 4}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data_x.data());
auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold = {
15, 43, 94, 62, 34, -16, 4, -31, 10, -6, 29, -13, -67, -45, 43, 27, 4, -83,
-21, -3, -6, 15, -3, 0, -9, 71, 78, 83, 3, -4, 62, 85, 45, 50, 27, 66,
26, -36, -29, 35, 97, 90, 2, -86, -62, 73, 127, 127, -32, -128, -128, -24, 83, 74,
-9, -63, -45, -35, 20, 1, 15, -12, -11, -72, -44, -46, 50, 40, 57, 25, 34, 18,
22, 30, 40, 105, 97, 88, -46, 26, 83, 127, 125, 69, -94, 24, 127, 127, 116, 4,
-128, -83, 83, 127, 127, -1, -66, -79, 40, 124, 127, 18, -19, -77, -15, 86, 127, 83};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// QLinearAveragePool (uint8) with auto_pad=SAME_LOWER: padding is applied at the
// start of each spatial axis, so the output has the same spatial dims as the input.
TEST_CASE(qlinearaveragepool_2d_same_lower_test)
{
auto p = migraphx::parse_onnx("qlinearaveragepool_2d_same_lower_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<uint8_t> data_x = {195, 102, 250, 61, 222, 6, 243, 218, 230, 105, 36, 116,
194, 31, 113, 85, 126, 204, 80, 38, 115, 167, 221, 67,
69, 140, 11, 209, 136, 120, 39, 96, 29, 5, 167, 40,
58, 51, 157, 179, 244, 149, 76, 243, 126, 144, 192, 199};
migraphx::shape s_x{migraphx::shape::uint8_type, {1, 3, 4, 4}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data_x.data());
auto result = p.eval(pp).back();
std::vector<uint8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// Output has the same 4x4 spatial footprint as the input.
std::vector<uint8_t> gold = {195, 148, 176, 156, 208, 131, 150, 193, 226, 141, 98, 153,
212, 140, 71, 88, 126, 165, 142, 59, 120, 153, 168, 102,
92, 123, 135, 127, 102, 116, 78, 89, 29, 17, 86, 104,
44, 36, 95, 136, 151, 126, 108, 164, 185, 166, 140, 178};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// QLinearAveragePool (int8) with auto_pad=SAME_UPPER: padding is applied at the
// end of each spatial axis; the requantized output saturates at +/-127/-128.
TEST_CASE(qlinearaveragepool_2d_same_upper_test)
{
auto p = migraphx::parse_onnx("qlinearaveragepool_2d_same_upper_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<int8_t> data_x = {-61, 102, -6, 61, -34, 6, -13, -38, -26, 105, 36, 116,
-62, 31, 113, 85, 126, -52, 80, 38, 115, -89, -35, 67,
69, -116, 11, -47, -120, 120, 39, 96, 29, 5, -89, 40,
58, 51, -99, -77, -12, -107, 76, -13, 126, -112, -64, -57};
migraphx::shape s_x{migraphx::shape::int8_type, {1, 3, 4, 4}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data_x.data());
auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold = {
-58, -20, -62, -41, -38, 3, -14, 14, -40, 78, 111, 127, -95, 80, 127, 106,
-14, -112, 11, 41, -74, -128, -66, -44, -88, -37, -14, -15, -64, 95, 71, 127,
8, -128, -128, -101, -69, -104, -120, -128, -116, -128, -93, -128, -50, -128, -128, -128};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// QLinearAveragePool (int8, input {1, 3, 8, 8}) with strides: the strided windows
// reduce each 8x8 channel to a small grid (12 outputs across 3 channels).
TEST_CASE(qlinearaveragepool_2d_strides_test)
{
auto p = migraphx::parse_onnx("qlinearaveragepool_2d_strides_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<int8_t> data_x = {
84, -73, 117, -2, -97, 72, 67, 27, 1, -44, 110, 51, 9, 7, 58, 113,
-34, 34, 124, -20, 6, 66, 68, 98, 31, -84, 25, 101, -69, -100, -68, 116,
33, -121, 78, 49, 102, -86, 65, 69, -87, -89, 16, -125, 51, -54, -86, 79,
-112, -37, -6, 74, 118, -75, -41, 52, 101, -22, -28, -92, -59, -128, 32, 78,
-20, 121, 11, -107, -92, -31, 81, 117, -55, -3, 80, 119, 126, -98, -11, 52,
-4, -66, 37, -57, -16, -33, -12, 100, 55, 2, 27, 62, -15, 64, -74, -21,
-123, 22, -45, 12, 30, 24, 20, 120, -36, -102, -75, -39, -76, 55, 74, -120,
103, 67, -80, -89, -112, 36, 69, 98, 110, -82, 60, 119, 98, 88, 5, 42,
-88, -86, -58, -33, 93, 80, -57, -56, 87, 7, -4, 114, -73, -91, -12, -123,
96, -99, -31, -99, 85, 34, -126, 106, 88, 126, -60, 14, 75, -117, -15, 6,
55, -14, 117, -87, -75, -50, -85, 54, 70, 125, 74, -100, 25, -112, 74, -66,
-116, -102, 1, -75, -107, 83, -120, -66, 57, 29, 62, -45, -103, -56, 90, -53};
migraphx::shape s_x{migraphx::shape::int8_type, {1, 3, 8, 8}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data_x.data());
auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold = {24, 37, 10, 17, 12, 12, -13, -1, 14, -10, 7, -19};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// QLinearAveragePool over 3-D spatial dims (int8, shape {1, 3, 3, 3, 3}); gold
// values were generated against the onnxruntime reference implementation.
TEST_CASE(qlinearaveragepool_3d_test)
{
auto p = migraphx::parse_onnx("qlinearaveragepool_3d_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<int8_t> data_x = {
-61, 102, -6, 61, -34, 6, -13, -38, -26, 105, 36, 116, -62, 31, 113, 85, 126,
-52, 80, 38, 115, -89, -35, 67, 69, -116, 11, -47, -120, 120, 39, 96, 29, 5,
-89, 40, 58, 51, -99, -77, -12, -107, 76, -13, 126, -112, -64, -57, 99, -54, 27,
99, 126, -46, -7, 109, 17, 77, 94, -92, 84, -92, 48, 71, 45, -102, 95, 118,
24, 13, -70, 33, 35, -60, 102, 81, 34, 108, -79, 14, -42};
migraphx::shape s_x{migraphx::shape::int8_type, {1, 3, 3, 3, 3}};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, data_x.data());
auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold = {56, 114, 49, 39, 32, 127, 3, 45, -4, -13, 8, 22,
-35, -98, 76, 15, 127, 67, 100, 20, 127, 84, 64, 68};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearaveragepool_notset_test)
{
    // QLinearAveragePool (int8) with auto_pad NOTSET (explicit pads, per the test
    // model name); a single requantized output element is produced.
    auto prog = migraphx::parse_onnx("qlinearaveragepool_notset_test.onnx");
    prog.compile(migraphx::make_target("ref"));

    std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                                 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
    migraphx::shape input_shape{migraphx::shape::int8_type, {1, 1, 5, 5}};
    migraphx::parameter_map params;
    params["x"] = migraphx::argument(input_shape, input.data());

    auto result = prog.eval(params).back();
    std::vector<int8_t> out;
    result.visit([&](auto output) { out.assign(output.begin(), output.end()); });

    std::vector<int8_t> gold = {22};
    EXPECT(migraphx::verify::verify_rms_range(out, gold));
}
TEST_CASE(qlinearaveragepool_nt_cip_test)
{
    // QLinearAveragePool (uint8), NOTSET padding with count_include_pad (per the
    // test model name); contrib-op reference:
    // github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearAveragePool
    auto prog = migraphx::parse_onnx("qlinearaveragepool_nt_cip_test.onnx");
    prog.compile(migraphx::make_target("ref"));

    std::vector<uint8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                                  13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
    migraphx::shape input_shape{migraphx::shape::uint8_type, {1, 1, 5, 5}};
    migraphx::parameter_map params;
    params["x"] = migraphx::argument(input_shape, input.data());

    auto result = prog.eval(params).back();
    std::vector<uint8_t> out;
    result.visit([&](auto output) { out.assign(output.begin(), output.end()); });

    // Counting the padded zeros in the divisor lowers the average vs. the
    // notset test above (18 instead of 22).
    std::vector<uint8_t> gold = {18};
    EXPECT(migraphx::verify::verify_rms_range(out, gold));
}
TEST_CASE(qlinearconcat_test)
{
    // QLinearConcat of two 1-D int8 tensors: the inputs are requantized into the
    // output's scale/zero-point, so the gold values are not a raw concatenation.
    auto prog = migraphx::parse_onnx("qlinearconcat_test.onnx");
    prog.compile(migraphx::make_target("ref"));

    migraphx::parameter_map params;
    std::vector<int8_t> t0_data = {2, 3};
    migraphx::shape t0_shape{migraphx::shape::int8_type, {2}};
    params["t0"] = migraphx::argument(t0_shape, t0_data.data());

    std::vector<int8_t> t1_data = {6, 8, 10};
    migraphx::shape t1_shape{migraphx::shape::int8_type, {3}};
    params["t1"] = migraphx::argument(t1_shape, t1_data.data());

    auto result = prog.eval(params).back();
    std::vector<int8_t> out;
    result.visit([&](auto output) { out.assign(output.begin(), output.end()); });

    std::vector<int8_t> gold = {3, 4, 5, 6, 7};
    EXPECT(migraphx::verify::verify_rms_range(out, gold));
}
TEST_CASE(qlinearconcat_3d_test)
{
    // QLinearConcat of two int8 3-D tensors ({3, 4, 2} and {3, 2, 2}); judging by
    // the gold layout they are joined along axis 1 after requantization into the
    // output's scale/zero-point.
    auto p = migraphx::parse_onnx("qlinearconcat_3d_test.onnx");
    p.compile(migraphx::make_target("ref"));

    std::vector<int8_t> data_t0 = {10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
                                   10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10};
    migraphx::shape s_t0{migraphx::shape::int8_type, {3, 4, 2}};
    migraphx::parameter_map pp;
    pp["t0"] = migraphx::argument(s_t0, data_t0.data());

    std::vector<int8_t> data_t1 = {25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25};
    migraphx::shape s_t1{migraphx::shape::int8_type, {3, 2, 2}};
    pp["t1"] = migraphx::argument(s_t1, data_t1.data());

    auto result = p.eval(pp).back();
    // Bug fix: the inputs and gold are int8, so read the output as int8_t; the
    // previous std::vector<uint8_t> relied on an unsigned re-interpretation of
    // signed output values and was inconsistent with every sibling test.
    std::vector<int8_t> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

    std::vector<int8_t> gold = {2, 2, 2, 2, 2, 2, 2, 2, 6, 6, 6, 6, 2, 2, 2, 2, 2, 2,
                                2, 2, 6, 6, 6, 6, 2, 2, 2, 2, 2, 2, 2, 2, 6, 6, 6, 6};
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearconv_test)
{
// https://xadupre.github.io/draft/onnx/onnx_doc_folder/onnx__QLinearConv.html
......@@ -1659,6 +2192,35 @@ TEST_CASE(qlinearglobalavgpool_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// QLinearLeakyRelu (int8): negative inputs are scaled by alpha, positives pass
// through, and the result is requantized (saturating at 127 for the largest inputs).
TEST_CASE(qlinearleakyrelu_test)
{
// Contrib-op reference (fixed anchor -- the old comment pointed at QLinearSigmoid):
// github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearLeakyRelu
migraphx::program p = migraphx::parse_onnx("qlinearleakyrelu_test.onnx");
p.compile(migraphx::make_target("ref"));
migraphx::shape x{migraphx::shape::int8_type, {64}};
// Full sweep of the int8 range in steps of 4.
std::vector<int8_t> data_x = {
-128, -124, -120, -116, -112, -108, -104, -100, -96, -92, -88, -84, -80, -76, -72, -68,
-64, -60, -56, -52, -48, -44, -40, -36, -32, -28, -24, -20, -16, -12, -8, -4,
0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60,
64, 68, 72, 76, 80, 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124};
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(x, data_x.data());
auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold = {
-128, -126, -122, -118, -113, -109, -104, -100, -96, -91, -87, -82, -78, -74, -69, -65,
-60, -56, -52, -47, -43, -38, -34, -30, -25, -21, -16, -12, -8, -3, 1, 6,
10, 14, 18, 22, 26, 30, 34, 38, 42, 46, 50, 54, 58, 62, 66, 70,
74, 78, 82, 86, 90, 94, 98, 102, 106, 110, 114, 118, 122, 126, 127, 127};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearmatmul_1D_test)
{
migraphx::program p = migraphx::parse_onnx("qlinearmatmul_1D_test.onnx");
......@@ -1735,6 +2297,111 @@ TEST_CASE(qlinearmatmul_3D_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// QLinearMul (uint8): element-wise product of two dequantized inputs,
// requantized into the output scale (saturating at 255 around the middle).
TEST_CASE(qlinearmul_test)
{
// github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearMul
migraphx::program p = migraphx::parse_onnx("qlinearmul_test.onnx");
p.compile(migraphx::make_target("ref"));
migraphx::shape a{migraphx::shape::uint8_type, {64}};
// A ramps up 0..126 while B ramps down 128..2.
std::vector<uint8_t> data_a = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24,
26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50,
52, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72, 74, 76,
78, 80, 82, 84, 86, 88, 90, 92, 94, 96, 98, 100, 102,
104, 106, 108, 110, 112, 114, 116, 118, 120, 122, 124, 126};
migraphx::shape b{migraphx::shape::uint8_type, {64}};
std::vector<uint8_t> data_b = {128, 126, 124, 122, 120, 118, 116, 114, 112, 110, 108, 106, 104,
102, 100, 98, 96, 94, 92, 90, 88, 86, 84, 82, 80, 78,
76, 74, 72, 70, 68, 66, 64, 62, 60, 58, 56, 54, 52,
50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26,
24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2};
migraphx::parameter_map pp;
pp["A"] = migraphx::argument(a, data_a.data());
pp["B"] = migraphx::argument(b, data_b.data());
auto result = p.eval(pp).back();
std::vector<uint8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<uint8_t> gold = {100, 111, 122, 132, 142, 151, 160, 169, 177, 185, 192, 199, 206,
212, 218, 223, 228, 233, 237, 241, 244, 247, 250, 252, 254, 255,
255, 255, 255, 255, 255, 255, 254, 252, 250, 247, 244, 241, 237,
233, 228, 223, 218, 212, 206, 199, 192, 185, 177, 169, 160, 151,
142, 132, 122, 111, 100, 89, 77, 65, 52, 39, 26, 12};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// QLinearMul (int8) with broadcasting: A is shape {64} and B is {1, 1, 64}, so
// the two broadcast to a common shape; the output saturates at -128/127.
TEST_CASE(qlinearmul_bcast_test)
{
// github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearMul
migraphx::program p = migraphx::parse_onnx("qlinearmul_bcast_test.onnx");
p.compile(migraphx::make_target("ref"));
migraphx::shape a{migraphx::shape::int8_type, {64}};
std::vector<int8_t> data_a = {-64, -62, -60, -58, -56, -54, -52, -50, -48, -46, -44, -42, -40,
-38, -36, -34, -32, -30, -28, -26, -24, -22, -20, -18, -16, -14,
-12, -10, -8, -6, -4, -2, 0, 2, 4, 6, 8, 10, 12,
14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38,
40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62};
migraphx::shape b{migraphx::shape::int8_type, {1, 1, 64}};
std::vector<int8_t> data_b = {96, 94, 92, 90, 88, 86, 84, 82, 80, 78, 76, 74, 72,
70, 68, 66, 64, 62, 60, 58, 56, 54, 52, 50, 48, 46,
44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20,
18, 16, 14, 12, 10, 8, 6, 4, 2, 0, -2, -4, -6,
-8, -10, -12, -14, -16, -18, -20, -22, -24, -26, -28, -30};
migraphx::parameter_map pp;
pp["A"] = migraphx::argument(a, data_a.data());
pp["B"] = migraphx::argument(b, data_b.data());
auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold = {-128, -128, -128, -128, -128, -128, -128, -128, -128, -126, -118,
-109, -101, -93, -86, -78, -70, -63, -56, -49, -42, -35,
-28, -21, -15, -9, -2, 4, 10, 15, 21, 27, 32,
37, 42, 47, 52, 57, 62, 66, 70, 75, 79, 83,
86, 90, 94, 97, 100, 103, 106, 109, 112, 115, 117,
119, 122, 124, 126, 127, 127, 127, 127, 127};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// QLinearSigmoid (int8): sigmoid computed on the dequantized input and
// requantized; the tails of the curve saturate at -128 and 127.
TEST_CASE(qlinearsigmoid_test)
{
// github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearSigmoid
migraphx::program p = migraphx::parse_onnx("qlinearsigmoid_test.onnx");
p.compile(migraphx::make_target("ref"));
migraphx::shape x{migraphx::shape::int8_type, {64}};
// Full sweep of the int8 range in steps of 4.
std::vector<int8_t> data_x = {
-128, -124, -120, -116, -112, -108, -104, -100, -96, -92, -88, -84, -80, -76, -72, -68,
-64, -60, -56, -52, -48, -44, -40, -36, -32, -28, -24, -20, -16, -12, -8, -4,
0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60,
64, 68, 72, 76, 80, 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124};
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(x, data_x.data());
auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int8_t> gold = {-128, -127, -127, -127, -127, -127, -126, -126, -126, -125, -125,
-124, -123, -122, -120, -119, -117, -114, -112, -108, -104, -99,
-94, -87, -80, -71, -62, -51, -39, -27, -13, 1, 15,
29, 43, 56, 69, 81, 92, 101, 110, 117, 124, 127,
127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127,
127, 127, 127, 127, 127, 127, 127, 127, 127};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(resize_downsample_f_test)
{
migraphx::program p = migraphx::parse_onnx("resize_downsample_f_test.onnx");
......@@ -1896,6 +2563,43 @@ TEST_CASE(reversesequence_time_verify_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(round_half_test)
{
    // Round on half-precision inputs, verifying round-half-to-even behavior
    // (e.g. -3.5 -> -4, -2.5 -> -2, 2.5 -> 2, 3.5 -> 4).
    migraphx::program p = migraphx::parse_onnx("round_half_test.onnx");
    p.compile(migraphx::make_target("ref"));

    migraphx::shape xs{migraphx::shape::half_type, {4, 4}};
    // Build the half input from float literals (no half literals in C++).
    std::vector<float> input_f = {-3.51, -3.5, -3.49, -2.51, -2.50, -2.49, -1.6, -1.5,
                                  -0.51, -0.5, 0.5,   0.6,   2.4,   2.5,   3.5,  4.5};
    std::vector<migraphx::half> data{input_f.cbegin(), input_f.cend()};

    migraphx::parameter_map param_map;
    param_map["x"] = migraphx::argument(xs, data.data());
    auto result    = p.eval(param_map).back();

    std::vector<migraphx::half> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

    std::vector<float> gold_f = {
        -4.0, -4.0, -3.0, -3.0, -2.0, -2.0, -2.0, -2.0, -1.0, 0.0, 0.0, 1.0, 2.0, 2.0, 4.0, 4.0};
    std::vector<migraphx::half> gold{gold_f.cbegin(), gold_f.cend()};
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(selu_test)
{
migraphx::program p = migraphx::parse_onnx("selu_test.onnx");
......@@ -2158,67 +2862,6 @@ TEST_CASE(softsign_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// Upsample: each 2x2 input pixel is repeated to produce the 4x6 gold output
// (2x along H, 3x along W, judging by the gold layout -- nearest mode presumably).
// NOTE(review): unlike the other tests in this file, there is no
// p.compile(migraphx::make_target("ref")) call before eval -- confirm whether
// evaluating the uncompiled program is intentional or an oversight.
// NOTE(review): a second TEST_CASE(upsample_test) appears later in this file
// (likely a merge artifact); duplicate test names will collide -- verify and
// remove one of the two definitions.
TEST_CASE(upsample_test)
{
migraphx::program p = migraphx::parse_onnx("upsample_test.onnx");
std::vector<float> x_data = {1, 2, 3, 4};
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(sx, x_data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(where_test)
{
    // Elementwise select with broadcast: cond {2}, x {2,2,2}, y {2,1,2,2}.
    migraphx::program p = migraphx::parse_onnx("where_test.onnx");
    p.compile(migraphx::make_target("ref"));

    migraphx::shape cond_s{migraphx::shape::bool_type, {2}};
    std::vector<int8_t> cond = {1, 0};
    migraphx::shape xs{migraphx::shape::float_type, {2, 2, 2}};
    std::vector<float> xv(8, 1.0f);
    migraphx::shape ys{migraphx::shape::float_type, {2, 1, 2, 2}};
    std::vector<float> yv(8, 2.0f);

    migraphx::parameter_map pm;
    pm["c"] = migraphx::argument(cond_s, cond.data());
    pm["x"] = migraphx::argument(xs, xv.data());
    pm["y"] = migraphx::argument(ys, yv.data());

    auto result = p.eval(pm).back();
    std::vector<float> out;
    result.visit([&](auto v) { out.assign(v.begin(), v.end()); });

    // The condition broadcasts along the last axis, so results alternate x, y.
    std::vector<float> gold;
    for(int i = 0; i < 8; ++i)
    {
        gold.push_back(1.0f);
        gold.push_back(2.0f);
    }
    EXPECT(migraphx::verify::verify_rms_range(out, gold));
}
std::vector<float> gen_trilu_test(const migraphx::shape& s, const migraphx::program& p)
{
// input data filled with values 1 to nelements
......@@ -2344,4 +2987,131 @@ TEST_CASE(tril_row_one_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(upsample_test)
{
    // Upsample a 2x2 input; gold contains 24 values where each source
    // element is repeated across the enlarged spatial grid.
    migraphx::program p = migraphx::parse_onnx("upsample_test.onnx");
    migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
    std::vector<float> x_data = {1, 2, 3, 4};
    migraphx::parameter_map pp;
    pp["X"] = migraphx::argument(sx, x_data.data());
    auto result = p.eval(pp).back();
    std::vector<float> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
    const std::vector<float> gold = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
                                     3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4};
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(unique_dynamic_sorted_test)
{
    // Unique on a dynamic 1-D input; golds show ascending output order,
    // so this exercises the sorted variant. All four outputs are checked.
    migraphx::program p = migraphx::parse_onnx("unique_dynamic_sorted_test.onnx");
    p.compile(migraphx::make_target("ref"));

    std::vector<float> x{2, 1, 1, 3, 4, 3};
    migraphx::shape s{migraphx::shape::float_type, {x.size()}};
    migraphx::parameter_map pm;
    pm["X"] = migraphx::argument(s, x.data());
    auto results = p.eval(pm);

    // Copy output number n into dst.
    auto fetch = [&](std::size_t n, auto& dst) {
        results[n].visit([&](auto out) { dst.assign(out.begin(), out.end()); });
    };

    std::vector<float> y;
    std::vector<size_t> y_idx;
    std::vector<size_t> x_idx;
    std::vector<size_t> y_ct;
    fetch(0, y);
    fetch(1, y_idx);
    fetch(2, x_idx);
    fetch(3, y_ct);

    const std::vector<float> y_gold = {1, 2, 3, 4};
    const std::vector<size_t> y_idx_gold = {1, 0, 3, 4};
    const std::vector<size_t> x_idx_gold = {1, 0, 0, 2, 3, 2};
    const std::vector<size_t> y_ct_gold = {2, 1, 2, 1};
    EXPECT(y == y_gold);
    EXPECT(y_idx == y_idx_gold);
    EXPECT(x_idx == x_idx_gold);
    EXPECT(y_ct == y_ct_gold);
}
TEST_CASE(unique_dynamic_unsorted_test)
{
    // Unsorted Unique: golds keep first-occurrence order (2, 1, 3, 4).
    migraphx::program p = migraphx::parse_onnx("unique_dynamic_unsorted_test.onnx");
    p.compile(migraphx::make_target("ref"));

    std::vector<float> x{2, 1, 1, 3, 4, 3};
    migraphx::shape s{migraphx::shape::float_type, {x.size()}};
    migraphx::parameter_map pm;
    pm["X"] = migraphx::argument(s, x.data());
    auto results = p.eval(pm);

    // Copy output number n into dst.
    auto fetch = [&](std::size_t n, auto& dst) {
        results[n].visit([&](auto out) { dst.assign(out.begin(), out.end()); });
    };

    std::vector<float> y;
    std::vector<size_t> y_idx;
    std::vector<size_t> x_idx;
    std::vector<size_t> y_ct;
    fetch(0, y);
    fetch(1, y_idx);
    fetch(2, x_idx);
    fetch(3, y_ct);

    const std::vector<float> y_gold = {2, 1, 3, 4};
    const std::vector<size_t> y_idx_gold = {0, 1, 3, 4};
    const std::vector<size_t> x_idx_gold = {0, 1, 1, 2, 3, 2};
    const std::vector<size_t> y_ct_gold = {1, 2, 2, 1};
    EXPECT(y == y_gold);
    EXPECT(y_idx == y_idx_gold);
    EXPECT(x_idx == x_idx_gold);
    EXPECT(y_ct == y_ct_gold);
}
TEST_CASE(where_test)
{
    // Where with broadcastable inputs: condition {2}, x {2,2,2}, y {2,1,2,2}.
    migraphx::program prog = migraphx::parse_onnx("where_test.onnx");
    prog.compile(migraphx::make_target("ref"));

    migraphx::shape cond_shape{migraphx::shape::bool_type, {2}};
    std::vector<int8_t> cond_data = {1, 0};
    migraphx::shape x_shape{migraphx::shape::float_type, {2, 2, 2}};
    std::vector<float> x_data(8, 1.0f);
    migraphx::shape y_shape{migraphx::shape::float_type, {2, 1, 2, 2}};
    std::vector<float> y_data(8, 2.0f);

    migraphx::parameter_map params;
    params["c"] = migraphx::argument(cond_shape, cond_data.data());
    params["x"] = migraphx::argument(x_shape, x_data.data());
    params["y"] = migraphx::argument(y_shape, y_data.data());

    auto result = prog.eval(params).back();
    std::vector<float> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

    // Alternating picks from x (1.0) and y (2.0).
    const std::vector<float> gold = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f,
                                     1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f};
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -88,7 +88,7 @@ TEST_CASE(allocate_static)
expect_shape(out_shape, migraphx::make_op("allocate", {{"shape", to_value(out_shape)}}));
}
TEST_CASE(allocate_static_input_error)
TEST_CASE(allocate_static_input)
{
migraphx::shape input{migraphx::shape::int64_type, {3}};
migraphx::shape out_shape{migraphx::shape::float_type, {2, 3, 4}};
......@@ -116,7 +116,7 @@ TEST_CASE(allocate_dyn_with_shape_attr)
input);
}
TEST_CASE(allocate_dyn_no_input_error)
TEST_CASE(allocate_dyn_no_input)
{
migraphx::shape shape_attr{migraphx::shape::float_type,
{{1, 4}, {3, 3}, {4, 8, {4, 6}}, {4, 8}, {4, 6}}};
......@@ -124,6 +124,21 @@ TEST_CASE(allocate_dyn_no_input_error)
migraphx::make_op("allocate", {{"shape", migraphx::to_value(shape_attr)}}));
}
TEST_CASE(allocate_shape_and_buf_type_error)
{
    // Specifying both the shape attribute and buf_type on allocate is rejected.
    migraphx::shape dyn_shape{migraphx::shape::float_type,
                              {{1, 4}, {3, 3}, {4, 8, {4, 6}}, {4, 8}, {4, 6}}};
    auto op = migraphx::make_op(
        "allocate",
        {{"shape", migraphx::to_value(dyn_shape)}, {"buf_type", migraphx::shape::half_type}});
    throws_shape(op);
}
TEST_CASE(allocate_no_attr_error)
{
    // allocate with no shape attribute fails shape computation for this input.
    migraphx::shape in_shape{migraphx::shape::int64_type, {4}};
    throws_shape(migraphx::make_op("allocate"), in_shape);
}
TEST_CASE(argmax_axis0)
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
......@@ -1942,12 +1957,42 @@ TEST_CASE(multibroadcast_3in_dyn_dyn)
expect_shape(expected_shape, migraphx::make_op("multibroadcast"), c_shape, a_shape, b_shape);
}
// NOTE(review): merge residue left two consecutive TEST_CASE headers here
// (the old name `multinomial` and the renamed `multinomial_bool_type`),
// which is ill-formed. Keep the renamed header; the body keeps both throw
// checks that appear in the merged view.
TEST_CASE(multinomial_bool_type)
{
    // dtype 0 is not a valid multinomial output type: both same-shape and
    // mismatched-shape inputs must be rejected.
    migraphx::shape s{migraphx::shape::float_type, {2, 5}};
    migraphx::shape s1{migraphx::shape::float_type, {1, 2}};
    migraphx::shape s2{migraphx::shape::float_type, {3, 4}};
    int dtype = 0;
    throws_shape(migraphx::make_op("multinomial", {{"dtype", dtype}}), s, s);
    throws_shape(migraphx::make_op("multinomial", {{"dtype", dtype}}), s1, s2);
}
TEST_CASE(multinomial)
{
    // Output takes dim 0 from the first input and dim 1 from the second:
    // {1,2} x {3,4} -> {1,4}.
    migraphx::shape in1{migraphx::shape::float_type, {1, 2}};
    migraphx::shape in2{migraphx::shape::float_type, {3, 4}};
    migraphx::shape expected{migraphx::shape::float_type, {1, 4}};
    const int dtype = 2;
    expect_shape(expected, migraphx::make_op("multinomial", {{"dtype", dtype}}), in1, in2);
}
TEST_CASE(multinomial_0size_input)
{
    // A second input with no dimensions is rejected.
    migraphx::shape in1{migraphx::shape::float_type, {1, 2}};
    migraphx::shape in2{migraphx::shape::float_type, {}};
    const int dtype = 2;
    throws_shape(migraphx::make_op("multinomial", {{"dtype", dtype}}), in1, in2);
}
TEST_CASE(multinomial_dyn)
{
    // Dynamic inputs: output keeps dim 0 of the first input and dim 1 of
    // the second.
    migraphx::shape in1{migraphx::shape::int32_type, {{2, 3}, {5, 6}}};
    migraphx::shape in2{migraphx::shape::int32_type, {{7, 8}, {9, 10}}};
    migraphx::shape expected{migraphx::shape::int32_type, {{2, 3}, {9, 10}}};
    auto op = migraphx::make_op("multinomial", {{"dtype", migraphx::shape::int32_type}});
    expect_shape(expected, op, in1, in2);
}
TEST_CASE(nms_shape)
......@@ -2157,7 +2202,8 @@ TEST_CASE(pooling_shape0)
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {1}},
{"stride", {0}},
{"lengths", {1}}}),
{"lengths", {1}},
{"dilations", {1}}}),
input);
}
......@@ -2170,7 +2216,8 @@ TEST_CASE(pooling_shape1)
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}}}),
{"lengths", {1, 1}},
{"dilations", {1, 1}}}),
input);
}
......@@ -2184,6 +2231,7 @@ TEST_CASE(pooling_shape2)
{"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}},
{"dilations", {1, 1}},
{"ceil_mode", true}}),
input);
}
......@@ -2198,6 +2246,7 @@ TEST_CASE(pooling_shape3)
{"padding", {2, 2}},
{"stride", {3, 3}},
{"lengths", {3, 3}},
{"dilations", {1, 1}},
{"ceil_mode", true}}),
input);
}
......@@ -2209,6 +2258,63 @@ TEST_CASE(pooling_shape4)
tiny_input);
}
TEST_CASE(pooling_shape5)
{
    // 2x2 window with dilation 2 spans 3 elements, covering the whole 3x3
    // spatial input -> 1x1 output.
    migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
    migraphx::shape output{migraphx::shape::float_type, {4, 3, 1, 1}};
    auto op = migraphx::make_op("pooling",
                                {{"mode", migraphx::op::pooling_mode::max},
                                 {"padding", {0, 0}},
                                 {"stride", {1, 1}},
                                 {"lengths", {2, 2}},
                                 {"dilations", {2, 2}}});
    expect_shape(output, op, input);
}
TEST_CASE(pooling_shape6)
{
    // 1x1 window (dilation is a no-op for length 1) with stride 2 on a 3x3
    // input -> 2x2 output.
    migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
    migraphx::shape output{migraphx::shape::float_type, {4, 3, 2, 2}};
    auto op = migraphx::make_op("pooling",
                                {{"mode", migraphx::op::pooling_mode::max},
                                 {"padding", {0, 0}},
                                 {"stride", {2, 2}},
                                 {"lengths", {1, 1}},
                                 {"dilations", {2, 2}}});
    expect_shape(output, op, input);
}
TEST_CASE(pooling_shape7)
{
    // Same as a stride-3 1x1 pooling but with ceil_mode, which rounds the
    // output spatial dims up to 2x2.
    migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
    migraphx::shape output{migraphx::shape::float_type, {4, 3, 2, 2}};
    auto op = migraphx::make_op("pooling",
                                {{"mode", migraphx::op::pooling_mode::max},
                                 {"padding", {0, 0}},
                                 {"stride", {3, 3}},
                                 {"lengths", {1, 1}},
                                 {"dilations", {3, 3}},
                                 {"ceil_mode", true}});
    expect_shape(output, op, input);
}
TEST_CASE(pooling_shape8)
{
    // 3x3 window at dilation 2 spans 5 elements; with padding 2 on each side
    // the padded 7x7 input yields a 3x3 output.
    migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
    migraphx::shape output{migraphx::shape::float_type, {4, 3, 3, 3}};
    auto op = migraphx::make_op("pooling",
                                {{"mode", migraphx::op::pooling_mode::max},
                                 {"padding", {2, 2}},
                                 {"stride", {1, 1}},
                                 {"lengths", {3, 3}},
                                 {"dilations", {2, 2}}});
    expect_shape(output, op, input);
}
TEST_CASE(pooling_dyn_shape0)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3, {3}}, {3, 3, {3}}, {3, 3}}};
......@@ -2216,7 +2322,8 @@ TEST_CASE(pooling_dyn_shape0)
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {1}},
{"stride", {0}},
{"lengths", {1}}}),
{"lengths", {1}},
{"dilations", {1}}}),
input);
}
......@@ -2229,7 +2336,8 @@ TEST_CASE(pooling_dyn_shape1)
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}}}),
{"lengths", {1, 1}},
{"dilations", {1, 1}}}),
input);
}
......@@ -2243,6 +2351,7 @@ TEST_CASE(pooling_dyn_shape2)
{"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}},
{"dilations", {1, 1}},
{"ceil_mode", true}}),
input);
}
......@@ -2257,7 +2366,8 @@ TEST_CASE(pooling_dyn_shape3)
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}}}),
{"lengths", {1, 1}},
{"dilations", {1, 1}}}),
input);
}
......@@ -2272,6 +2382,7 @@ TEST_CASE(pooling_dyn_shape4)
{"padding", {2, 2}},
{"stride", {3, 3}},
{"lengths", {3, 3}},
{"dilations", {1, 1}},
{"ceil_mode", true}}),
input);
}
......@@ -2571,36 +2682,26 @@ TEST_CASE(reshape_shape_minus1_reshapes)
}
}
// This uses the permutation to compute the reshape since its simpler than
// trying to calculate strides. As we collapse or expand dimensions, we
// remove the collapsed dimensions or duplicate the expanded dimensions in
// the permutation. Then we renumber the permutation. So for dimensions of 4,
// 24, 1, 1, 1 with a permutation of 1, 0, 2, 3, 4 that reshapes to 4, 1, 3,
// 4, 2, we first remove the collapsed dimensions or duplicate the expanded
// dimensions which gives 1, 0, 0, 0, 0. Then after renumbering we get a
// final permutation of 4, 0, 1, 2, 3.
// NOTE(review): this block contained interleaved removed/added diff lines
// (two `tests` initializers, two loop headers, two `output` declarations),
// which does not compile. Resolved to the added version: reshape of a
// non-standard input now always produces a standard packed output shape.
TEST_CASE(reshape_nonstandard)
{
    // Transposed 5-D input; each target dims list must reshape to a plain
    // standard shape.
    auto input = migraphx::shape::from_permutation(migraphx::shape::float_type,
                                                   {4, 24, 1, 1, 1},
                                                   migraphx::invert_permutation({1, 0, 2, 3, 4}));
    std::vector<std::vector<std::size_t>> tests{{4, 24},
                                                {4, 24, 1, 1, 1, 1},
                                                {4, 8, 3, 1, 1},
                                                {4, 1, 3, 4, 2},
                                                {4, 1, 4, 3, 2},
                                                {4, 2, 4, 3},
                                                {4, 2, 12, 1},
                                                {4, 2, 1, 12},
                                                {4, 4, 2, 3},
                                                {4, 8, 1, 3},
                                                {4, 8, 3, 1}};
    for(auto dims : tests)
    {
        migraphx::shape output = migraphx::shape{migraphx::shape::float_type, dims};
        expect_shape(output, migraphx::make_op("reshape", {{"dims", dims}}), input);
    }
}
......@@ -2610,8 +2711,7 @@ TEST_CASE(reshape_nonstandard_squeeze)
auto input = migraphx::shape::from_permutation(
migraphx::shape::float_type, {2, 16, 16, 1280}, migraphx::invert_permutation({0, 2, 3, 1}));
std::vector<std::size_t> lens = {2, 256, 1280};
migraphx::shape output = migraphx::shape::from_permutation(
migraphx::shape::float_type, lens, migraphx::invert_permutation({0, 2, 1}));
migraphx::shape output = migraphx::shape{migraphx::shape::float_type, lens};
expect_shape(output, migraphx::make_op("reshape", {{"dims", lens}}), input);
}
......@@ -2635,52 +2735,80 @@ TEST_CASE(reshape_nonstandard_error)
}
}
TEST_CASE(reshape_transposed_squeeze)
{
    // Flattening a transposed 4x16 input (strides {1, 4}) produces a
    // standard packed {64} shape.
    migraphx::shape in{migraphx::shape::float_type, {4, 16}, {1, 4}};
    migraphx::shape out{migraphx::shape::float_type, {64}};
    expect_shape(out, migraphx::make_op("reshape", {{"dims", out.lens()}}), in);
}
// NOTE(review): merge residue left two `output` declarations (old strided
// expectation and new standard-shape expectation); kept the added one.
TEST_CASE(reshape_nonpacked_unsqueeze1)
{
    // Unsqueezing a non-packed input yields a standard output shape.
    migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
    migraphx::shape output{migraphx::shape::float_type, {4, 2, 8}};
    expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
// NOTE(review): duplicate `output` declaration from the merge; kept the
// added standard-shape version.
TEST_CASE(reshape_nonpacked_unsqueeze2)
{
    // Unsqueezing along the leading axis of a non-packed input.
    migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
    migraphx::shape output{migraphx::shape::float_type, {2, 2, 16}};
    expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
// NOTE(review): merge residue left both the old header
// (reshape_nonpacked_squeeze) and the renamed one, plus two `output`
// declarations. Kept the renamed header and the added standard output.
TEST_CASE(reshape_nonpacked_squeeze1)
{
    // Flattening a non-packed input produces a standard {64} shape.
    migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
    migraphx::shape output{migraphx::shape::float_type, {64}};
    expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_nonpacked_squeeze2)
{
    // Same flatten as squeeze1 but phrased as its own case.
    migraphx::shape in{migraphx::shape::float_type, {4, 16}, {32, 2}};
    migraphx::shape out{migraphx::shape::float_type, {64}};
    expect_shape(out, migraphx::make_op("reshape", {{"dims", out.lens()}}), in);
}
// NOTE(review): duplicate `output` declaration from the merge; kept the
// added standard-shape version.
TEST_CASE(reshape_broadcast_unsqueeze1)
{
    // Unsqueezing a broadcast input (zero strides) yields a standard shape.
    migraphx::shape input{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
    migraphx::shape output{migraphx::shape::float_type, {2, 16, 16, 1280}};
    expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
// NOTE(review): duplicate `output` declaration from the merge; kept the
// added standard-shape version.
TEST_CASE(reshape_broadcast_unsqueeze2)
{
    // Splitting the last axis of a broadcast input into {16, 80}.
    migraphx::shape input{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
    migraphx::shape output{migraphx::shape::float_type, {2, 256, 16, 80}};
    expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
// NOTE(review): merge residue left both the old header
// (reshape_broadcast_squeeze) and the renamed one, plus two `output`
// declarations. Kept the renamed header and the added standard output.
TEST_CASE(reshape_broadcast_squeeze1)
{
    // Merging the two broadcast middle axes into one.
    migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
    migraphx::shape output{migraphx::shape::float_type, {2, 256, 1280}};
    expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_broadcast_squeeze2)
{
    // Input broadcast along axis 0 (stride 0) still flattens to {64}.
    migraphx::shape in{migraphx::shape::float_type, {4, 16}, {0, 1}};
    migraphx::shape out{migraphx::shape::float_type, {64}};
    expect_shape(out, migraphx::make_op("reshape", {{"dims", out.lens()}}), in);
}
TEST_CASE(reshape_broadcast_squeeze3)
{
    // Input broadcast along axis 1 (stride 0) still flattens to {64}.
    migraphx::shape in{migraphx::shape::float_type, {4, 16}, {1, 0}};
    migraphx::shape out{migraphx::shape::float_type, {64}};
    expect_shape(out, migraphx::make_op("reshape", {{"dims", out.lens()}}), in);
}
// NOTE(review): duplicate `output` declaration from the merge; kept the
// added standard-shape version.
TEST_CASE(reshape_broadcast_squeeze_memlayout_change)
{
    // Regrouping the broadcast axes changes the memory layout, so the
    // result is a standard shape rather than an aliasing view.
    migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
    migraphx::shape output{migraphx::shape::float_type, {2, 16, 256, 80}};
    expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
......@@ -2849,6 +2977,12 @@ TEST_CASE(reshape_lazy_nonstandard_error)
}
}
TEST_CASE(reshape_lazy_transposed_squeeze)
{
    // reshape_lazy cannot produce a flat alias of a transposed input.
    migraphx::shape in{migraphx::shape::float_type, {4, 16}, {1, 4}};
    throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), in);
}
TEST_CASE(reshape_lazy_nonpacked_unsqueeze1)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
......@@ -2863,13 +2997,19 @@ TEST_CASE(reshape_lazy_nonpacked_unsqueeze2)
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
}
// NOTE(review): merge residue left both the old header
// (reshape_lazy_nonpacked_squeeze) and the renamed one; kept the rename.
TEST_CASE(reshape_lazy_nonpacked_squeeze1)
{
    // Uniform stride 2 flattens lazily: {4,16}/{32,2} aliases as {64}/{2}.
    migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
    migraphx::shape output{migraphx::shape::float_type, {64}, {2}};
    expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_lazy_nonpacked_squeeze2)
{
    // Strides {32, 1} leave gaps between rows, so no lazy flat alias exists.
    migraphx::shape in{migraphx::shape::float_type, {4, 16}, {32, 1}};
    throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), in);
}
TEST_CASE(reshape_lazy_broadcast_unsqueeze1)
{
migraphx::shape input{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
......@@ -2884,13 +3024,25 @@ TEST_CASE(reshape_lazy_broadcast_unsqueeze2)
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
}
// NOTE(review): merge residue left both the old header
// (reshape_lazy_broadcast_squeeze) and the renamed one; kept the rename.
TEST_CASE(reshape_lazy_broadcast_squeeze1)
{
    // Merging the two fully-broadcast middle axes keeps a valid alias view.
    migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
    migraphx::shape output{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
    expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_lazy_broadcast_squeeze2)
{
    // Collapsing a broadcast axis (stride 0 on axis 0) into real data has
    // no aliasing view, so reshape_lazy must throw.
    migraphx::shape in{migraphx::shape::float_type, {4, 16}, {0, 1}};
    throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), in);
}
TEST_CASE(reshape_lazy_broadcast_squeeze3)
{
    // Same as squeeze2 but broadcast along axis 1 (stride 0); still throws.
    migraphx::shape in{migraphx::shape::float_type, {4, 16}, {1, 0}};
    throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), in);
}
TEST_CASE(reshape_lazy_broadcast_squeeze_error)
{
migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
......@@ -3188,6 +3340,64 @@ TEST_CASE(slice_static_shape)
TEST_CASE(slice_var_inputs_static_shape0)
{
    // ends and axes fixed as attributes; starts arrive as a runtime input,
    // so the sliced axes become dynamic with a minimum of 0.
    migraphx::shape data{migraphx::shape::float_type, {3, 4, 4}};
    migraphx::shape starts{migraphx::shape::int64_type, {2}};
    auto op = migraphx::make_op("slice", {{"ends", {2, 3}}, {"axes", {1, 2}}});
    expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3}, {0, 4}, {0, 4}}},
                 op,
                 data,
                 starts);
}
TEST_CASE(slice_var_inputs_static_mismatch_error0)
{
    // Three ends/axes entries but a starts input holding only two values.
    migraphx::shape data{migraphx::shape::float_type, {3, 4, 4}};
    migraphx::shape starts{migraphx::shape::int64_type, {2}};
    auto op = migraphx::make_op("slice", {{"ends", {2, 3, 4}}, {"axes", {0, 1, 2}}});
    throws_shape(op, data, starts);
}
TEST_CASE(slice_var_inputs_static_shape1)
{
    // starts and axes fixed as attributes; ends arrive as a runtime input.
    migraphx::shape data{migraphx::shape::float_type, {3, 4, 4}};
    migraphx::shape ends{migraphx::shape::int64_type, {2}};
    auto op = migraphx::make_op("slice", {{"starts", {0, 1}}, {"axes", {1, 2}}});
    expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3}, {0, 4}, {0, 4}}},
                 op,
                 data,
                 ends);
}
TEST_CASE(slice_var_inputs_static_mismatch_error1)
{
    // Three starts/axes entries but an ends input holding only two values.
    migraphx::shape data{migraphx::shape::float_type, {3, 4, 4}};
    migraphx::shape ends{migraphx::shape::int64_type, {2}};
    auto op = migraphx::make_op("slice", {{"starts", {0, 1, 2}}, {"axes", {0, 1, 2}}});
    throws_shape(op, data, ends);
}
TEST_CASE(slice_var_inputs_static_shape2)
{
    // starts and ends fixed; axes arrive at runtime, so any dim may shrink
    // and every output dimension becomes {0, original}.
    migraphx::shape data{migraphx::shape::float_type, {3, 4, 4}};
    migraphx::shape axes{migraphx::shape::int64_type, {2}};
    auto op = migraphx::make_op("slice", {{"starts", {0, 1}}, {"ends", {1, 2}}});
    expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 3}, {0, 4}, {0, 4}}},
                 op,
                 data,
                 axes);
}
TEST_CASE(slice_var_inputs_static_mismatch_error2)
{
    // Three starts/ends entries but an axes input holding only two values.
    migraphx::shape data{migraphx::shape::float_type, {3, 4, 4}};
    migraphx::shape axes{migraphx::shape::int64_type, {2}};
    auto op = migraphx::make_op("slice", {{"starts", {0, 1, 2}}, {"ends", {3, 4, 4}}});
    throws_shape(op, data, axes);
}
TEST_CASE(slice_var_inputs_static_shape3)
{
// attr axes set; inputs are (data, input_starts, input_ends)
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
......@@ -3198,7 +3408,57 @@ TEST_CASE(slice_var_inputs_static_shape0)
ends);
}
// NOTE(review): merge residue left both the old header
// (slice_var_inputs_static_shape1) and the renamed one; kept the rename.
TEST_CASE(slice_var_inputs_static_mismatch_error3)
{
    // axes attribute has three entries but starts/ends inputs hold two.
    migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
    migraphx::shape starts{migraphx::shape::int64_type, {2}};
    migraphx::shape ends{migraphx::shape::int64_type, {2}};
    throws_shape(migraphx::make_op("slice", {{"axes", {0, 1, 2}}}), input, starts, ends);
}
TEST_CASE(slice_var_inputs_static_shape4)
{
    // ends fixed as an attribute; starts and axes arrive at runtime.
    migraphx::shape data{migraphx::shape::float_type, {3, 4, 4}};
    migraphx::shape starts{migraphx::shape::int64_type, {2}};
    migraphx::shape axes{migraphx::shape::int64_type, {2}};
    auto op = migraphx::make_op("slice", {{"ends", {3, 4}}});
    expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 3}, {0, 4}, {0, 4}}},
                 op,
                 data,
                 starts,
                 axes);
}
TEST_CASE(slice_var_inputs_static_mismatch_error4)
{
    // Three ends entries but starts/axes inputs hold only two values.
    migraphx::shape data{migraphx::shape::float_type, {3, 4, 4}};
    migraphx::shape starts{migraphx::shape::int64_type, {2}};
    migraphx::shape axes{migraphx::shape::int64_type, {2}};
    throws_shape(migraphx::make_op("slice", {{"ends", {3, 3, 3}}}), data, starts, axes);
}
TEST_CASE(slice_var_inputs_static_shape5)
{
    // starts fixed as an attribute; ends and axes arrive at runtime.
    migraphx::shape data{migraphx::shape::float_type, {3, 4, 4}};
    migraphx::shape ends{migraphx::shape::int64_type, {2}};
    migraphx::shape axes{migraphx::shape::int64_type, {2}};
    auto op = migraphx::make_op("slice", {{"starts", {0, 2}}});
    expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 3}, {0, 4}, {0, 4}}},
                 op,
                 data,
                 ends,
                 axes);
}
TEST_CASE(slice_var_inputs_static_mismatch_error5)
{
    // Three starts entries but ends/axes inputs hold only two values.
    migraphx::shape data{migraphx::shape::float_type, {3, 4, 4}};
    migraphx::shape ends{migraphx::shape::int64_type, {2}};
    migraphx::shape axes{migraphx::shape::int64_type, {2}};
    throws_shape(migraphx::make_op("slice", {{"starts", {0, 1, 2}}}), data, ends, axes);
}
TEST_CASE(slice_var_inputs_static_shape6)
{
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
......@@ -3212,7 +3472,7 @@ TEST_CASE(slice_var_inputs_static_shape1)
axes);
}
TEST_CASE(slice_var_inputs_static_error0)
TEST_CASE(slice_var_inputs_static_mismatch_error6)
{
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
......@@ -3223,17 +3483,125 @@ TEST_CASE(slice_var_inputs_static_error0)
// NOTE(review): merge residue left two `input` declarations (old dynamic
// dims and the rewritten ones); kept the added version, which matches the
// expected output below.
TEST_CASE(slice_var_inputs_dyn_shape0)
{
    // attr ends and axes set; inputs are (data, input_starts)
    migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
    migraphx::shape starts{migraphx::shape::int64_type, {2}};
    expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 6}, {0, 6}, {0, 6}}},
                 migraphx::make_op("slice", {{"ends", {2, 3}}, {"axes", {1, 2}}}),
                 input,
                 starts);
}
TEST_CASE(slice_var_inputs_dyn_mismatch_error0)
{
    // Dynamic data: three ends/axes entries vs. a two-element starts input.
    migraphx::shape data{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
    migraphx::shape starts{migraphx::shape::int64_type, {2}};
    auto op = migraphx::make_op("slice", {{"ends", {2, 3, 4}}, {"axes", {0, 1, 2}}});
    throws_shape(op, data, starts);
}
TEST_CASE(slice_var_inputs_dyn_shape1)
{
    // starts and axes fixed; ends arrive at runtime on a dynamic input.
    migraphx::shape data{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
    migraphx::shape ends{migraphx::shape::int64_type, {2}};
    auto op = migraphx::make_op("slice", {{"starts", {0, 1}}, {"axes", {1, 2}}});
    expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 6}, {0, 6}, {0, 6}}},
                 op,
                 data,
                 ends);
}
TEST_CASE(slice_var_inputs_dyn_mismatch_error1)
{
    // Dynamic data: three starts/axes entries vs. a two-element ends input.
    migraphx::shape data{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
    migraphx::shape ends{migraphx::shape::int64_type, {2}};
    auto op = migraphx::make_op("slice", {{"starts", {0, 1, 2}}, {"axes", {0, 1, 2}}});
    throws_shape(op, data, ends);
}
TEST_CASE(slice_var_inputs_dyn_shape2)
{
    // starts and ends fixed; axes arrive at runtime on a dynamic input, so
    // every output dimension becomes {0, max}.
    migraphx::shape data{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
    migraphx::shape axes{migraphx::shape::int64_type, {2}};
    auto op = migraphx::make_op("slice", {{"starts", {0, 1}}, {"ends", {8, 8}}});
    expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 6}, {0, 6}, {0, 6}}},
                 op,
                 data,
                 axes);
}
TEST_CASE(slice_var_inputs_dyn_mismatch_error2)
{
    // Dynamic data: three starts/ends entries vs. a two-element axes input.
    migraphx::shape data{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
    migraphx::shape axes{migraphx::shape::int64_type, {2}};
    auto op = migraphx::make_op("slice", {{"starts", {0, 1, 2}}, {"ends", {3, 4, 4}}});
    throws_shape(op, data, axes);
}
// NOTE(review): merge residue left two first lines of the expect_shape call
// (old bound {0, 4} and new bound {0, 6}); kept the added version.
TEST_CASE(slice_var_inputs_dyn_shape3)
{
    // attr axes set; inputs are (data, input_starts, input_ends)
    migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
    migraphx::shape starts{migraphx::shape::int64_type, {2}};
    migraphx::shape ends{migraphx::shape::int64_type, {2}};
    expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 6}, {0, 6}, {0, 6}}},
                 migraphx::make_op("slice", {{"axes", {1, 2}}}),
                 input,
                 starts,
                 ends);
}
// NOTE(review): merge residue left both the old header
// (slice_var_inputs_dyn_shape1) and the renamed one; kept the rename.
TEST_CASE(slice_var_inputs_dyn_mismatch_error3)
{
    // axes attribute has three entries but starts/ends inputs hold two.
    migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
    migraphx::shape starts{migraphx::shape::int64_type, {2}};
    migraphx::shape ends{migraphx::shape::int64_type, {2}};
    throws_shape(migraphx::make_op("slice", {{"axes", {0, 1, 2}}}), input, starts, ends);
}
TEST_CASE(slice_var_inputs_dyn_shape4)
{
    // ends fixed as an attribute; starts and axes arrive at runtime.
    migraphx::shape data{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
    migraphx::shape starts{migraphx::shape::int64_type, {2}};
    migraphx::shape axes{migraphx::shape::int64_type, {2}};
    auto op = migraphx::make_op("slice", {{"ends", {3, 4}}});
    expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 6}, {0, 6}, {0, 6}}},
                 op,
                 data,
                 starts,
                 axes);
}
TEST_CASE(slice_var_inputs_dyn_mismatch_error4)
{
    // Three ends entries but starts/axes inputs hold only two values.
    migraphx::shape data{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
    migraphx::shape starts{migraphx::shape::int64_type, {2}};
    migraphx::shape axes{migraphx::shape::int64_type, {2}};
    throws_shape(migraphx::make_op("slice", {{"ends", {3, 3, 3}}}), data, starts, axes);
}
TEST_CASE(slice_var_inputs_dyn_shape5)
{
    // starts fixed as an attribute; ends and axes arrive at runtime.
    migraphx::shape data{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
    migraphx::shape ends{migraphx::shape::int64_type, {2}};
    migraphx::shape axes{migraphx::shape::int64_type, {2}};
    auto op = migraphx::make_op("slice", {{"starts", {0, 2}}});
    expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 6}, {0, 6}, {0, 6}}},
                 op,
                 data,
                 ends,
                 axes);
}
// attr starts has 3 entries but the ends/axes inputs supply only 2 -> throw
TEST_CASE(slice_var_inputs_dyn_mismatch_error5)
{
    migraphx::shape data_shape{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
    migraphx::shape ends_shape{migraphx::shape::int64_type, {2}};
    migraphx::shape axes_shape{migraphx::shape::int64_type, {2}};
    throws_shape(migraphx::make_op("slice", {{"starts", {0, 1, 2}}}),
                 data_shape,
                 ends_shape,
                 axes_shape);
}
TEST_CASE(slice_var_inputs_dyn_shape6)
{
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {2, 4, {2, 4}}, {2, 4, {2, 4}}}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
......@@ -3247,6 +3615,15 @@ TEST_CASE(slice_var_inputs_dyn_shape1)
axes);
}
// no attributes; the variable axes input has 3 elements while starts/ends
// supply only 2 -> shape computation must throw
TEST_CASE(slice_var_inputs_dyn_mismatch_error6)
{
    migraphx::shape data_shape{migraphx::shape::float_type, {{3, 6}, {4, 6}, {4, 6}}};
    migraphx::shape starts_shape{migraphx::shape::int64_type, {2}};
    migraphx::shape ends_shape{migraphx::shape::int64_type, {2}};
    migraphx::shape axes_shape{migraphx::shape::int64_type, {3}};
    throws_shape(
        migraphx::make_op("slice"), data_shape, starts_shape, ends_shape, axes_shape);
}
TEST_CASE(slice_dyn_shape0)
{
migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 7}, {2, 3}}};
......@@ -3830,6 +4207,40 @@ TEST_CASE(test_squeeze_wrong_axis)
throws_shape(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
}
// unique with axis = -1 on a rank-3 input is rejected by shape computation
// (the sibling test below shows axis = -3, i.e. axis 0, is accepted)
TEST_CASE(test_unique_axis_invalid)
{
    migraphx::shape in_shape{migraphx::shape::float_type, {10, 4, 3}};
    throws_shape(migraphx::make_op("unique", {{"axis", -1}}), in_shape);
}
TEST_CASE(test_unique_axis_negative)
{
    // axis = -3 on a rank-3 input addresses dimension 0. How many unique
    // slices exist is only known at run time, so dim 0 of the values output
    // and all three index outputs are dynamic in [1, 10].
    migraphx::shape in_shape{migraphx::shape::float_type, {10, 4, 3}};
    std::vector<migraphx::shape::dynamic_dimension> value_dims{{1, 10}, {4, 4}, {3, 3}};
    std::vector<migraphx::shape::dynamic_dimension> index_dims{{1, 10}};
    std::vector<migraphx::shape> expected{{migraphx::shape::float_type, value_dims},
                                          {migraphx::shape::int64_type, index_dims},
                                          {migraphx::shape::int64_type, index_dims},
                                          {migraphx::shape::int64_type, index_dims}};
    expect_shape(expected, migraphx::make_op("unique", {{"axis", -3}}), in_shape);
}
TEST_CASE(test_unique_axis_none)
{
    // no axis attribute: all 10 * 4 * 3 = 120 elements are considered, so the
    // values output and every index output get one dynamic dim in [1, 120]
    migraphx::shape in_shape{migraphx::shape::half_type, {10, 4, 3}};
    std::vector<migraphx::shape::dynamic_dimension> value_dims{{1, 120}};
    std::vector<migraphx::shape::dynamic_dimension> index_dims{{1, 120}};
    std::vector<migraphx::shape> expected{{migraphx::shape::half_type, value_dims},
                                          {migraphx::shape::int64_type, index_dims},
                                          {migraphx::shape::int64_type, index_dims},
                                          {migraphx::shape::int64_type, index_dims}};
    expect_shape(expected, migraphx::make_op("unique"), in_shape);
}
TEST_CASE(test_unsqueeze)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
......
......@@ -28,6 +28,7 @@ set(VENV_ONNX ${CMAKE_BINARY_DIR}/test/py/venv-onnx)
set(REQUIREMENTS ${CMAKE_CURRENT_SOURCE_DIR}/requirements.txt)
set(REQUIREMENTS_ONNX ${CMAKE_CURRENT_SOURCE_DIR}/requirements-onnx.txt)
set(PYTHON_VERSION_TO_DISABLE_ONNX 3.6)
option(MIGRAPHX_DISABLE_VIRTUAL_ENV "Disable python virtual environments" OFF)
function(add_py_venv_fixture FIXTURE_NAME VIRTUAL_ENV_DIR REQUIREMENTS_FILE)
......@@ -44,9 +45,17 @@ function(add_py_venv_fixture FIXTURE_NAME VIRTUAL_ENV_DIR REQUIREMENTS_FILE)
add_test(NAME py_${PYTHON_VERSION}_${FIXTURE_NAME}_initialize_env COMMAND ${PYTHON_EXECUTABLE} -m venv ${VIRTUAL_ENV_DIR}/${PYTHON_VERSION} --clear)
set_tests_properties(py_${PYTHON_VERSION}_${FIXTURE_NAME}_initialize_env PROPERTIES FIXTURES_SETUP ${FIXTURE_NAME}_${PYTHON_VERSION}_INIT_VENV)
set(PYTHON_EXECUTABLE ${VIRTUAL_ENV_DIR}/${PYTHON_VERSION}/bin/python)
if(EXISTS ${REQUIREMENTS_FILE})
add_test(
NAME py_${PYTHON_VERSION}_${FIXTURE_NAME}_setup_env
COMMAND ${PYTHON_EXECUTABLE} -m pip install -r ${REQUIREMENTS_FILE})
else()
# If there is no requirements file, then there are no packages to install in the virtual env.
# Just create a placeholder test for setting up the required fixture for running the tests.
add_test(
NAME py_${PYTHON_VERSION}_${FIXTURE_NAME}_setup_env
COMMAND ${PYTHON_EXECUTABLE} -m pip install --help)
endif()
set_tests_properties(py_${PYTHON_VERSION}_${FIXTURE_NAME}_setup_env PROPERTIES FIXTURES_REQUIRED ${FIXTURE_NAME}_${PYTHON_VERSION}_INIT_VENV)
set_tests_properties(py_${PYTHON_VERSION}_${FIXTURE_NAME}_setup_env PROPERTIES FIXTURES_SETUP ${FIXTURE_NAME}_${PYTHON_VERSION}_VENV)
endif()
......@@ -61,23 +70,31 @@ function(add_py_test NAME SCRIPT FIXTURE_NAME VENV_DIR)
"PYTHONMALLOC=debug"
"MALLOC_CHECK_=3"
)
if(MIGRAPHX_DISABLE_VIRTUAL_ENV)
set(PYTHON_EXECUTABLE ${PYTHON_${PYTHON_VERSION}_EXECUTABLE})
else()
set(PYTHON_EXECUTABLE ${VENV_DIR}/${PYTHON_VERSION}/bin/python)
endif()
if(NOT (${FIXTURE_NAME} STREQUAL "onnx" AND ${PYTHON_VERSION} STREQUAL ${PYTHON_VERSION_TO_DISABLE_ONNX}))
add_test(
NAME test_py_${PYTHON_VERSION}_${NAME}
COMMAND ${ENV_COMMAND} ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/${SCRIPT} ${ARGN})
set_tests_properties(test_py_${PYTHON_VERSION}_${NAME} PROPERTIES FIXTURES_REQUIRED ${FIXTURE_NAME}_${PYTHON_VERSION}_VENV)
add_custom_target(test_py_${PYTHON_VERSION}_${NAME}
COMMAND ${ENV_COMMAND} ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/${SCRIPT} ${ARGN}
COMMENT "${PYTHON_EXECUTABLE} ${SCRIPT}")
if(NOT MIGRAPHX_DISABLE_VIRTUAL_ENV)
set_tests_properties(test_py_${PYTHON_VERSION}_${NAME} PROPERTIES FIXTURES_REQUIRED ${FIXTURE_NAME}_${PYTHON_VERSION}_VENV)
endif()
endif()
endforeach()
endfunction()
add_dependencies(tests migraphx_py)
add_dependencies(check migraphx_py)
add_py_venv_fixture(common ${VENV} ${REQUIREMENTS})
add_py_venv_fixture(onnx ${VENV_ONNX} ${REQUIREMENTS_ONNX})
if(NOT MIGRAPHX_DISABLE_VIRTUAL_ENV)
add_py_venv_fixture(common ${VENV} ${REQUIREMENTS})
add_py_venv_fixture(onnx ${VENV_ONNX} ${REQUIREMENTS_ONNX})
endif()
add_py_test(ref test_cpu.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(save_load test_save_load.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR})
......
......@@ -83,7 +83,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
backend_test.exclude(r'test_nonmaxsuppression_two_batches_cpu')
backend_test.exclude(r'test_nonmaxsuppression_two_classes_cpu')
backend_test.exclude(r'test_nonzero_example_cpu')
backend_test.exclude(r'test_round_cpu')
backend_test.exclude(r'test_softmax_axis_0_cpu')
backend_test.exclude(r'test_softmax_axis_1_cpu')
backend_test.exclude(r'test_softmax_default_axis_cpu')
......@@ -119,9 +118,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
backend_test.exclude(r'test_convtranspose_1d_cpu')
backend_test.exclude(r'test_det_2d_cpu')
backend_test.exclude(r'test_det_nd_cpu')
backend_test.exclude(r'test_dynamicquantizelinear_cpu')
backend_test.exclude(r'test_dynamicquantizelinear_max_adjusted_cpu')
backend_test.exclude(r'test_dynamicquantizelinear_min_adjusted_cpu')
backend_test.exclude(r'test_edge_pad_cpu')
backend_test.exclude(r'test_einsum_batch_diagonal_cpu')
backend_test.exclude(r'test_einsum_batch_matmul_cpu')
......@@ -135,9 +131,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
backend_test.exclude(r'test_hardmax_example_cpu')
backend_test.exclude(r'test_hardmax_negative_axis_cpu')
backend_test.exclude(r'test_hardmax_one_hot_cpu')
backend_test.exclude(r'test_isinf_cpu')
backend_test.exclude(r'test_isinf_negative_cpu')
backend_test.exclude(r'test_isinf_positive_cpu')
backend_test.exclude(r'test_matmulinteger_cpu')
backend_test.exclude(r'test_maxpool_2d_uint8_cpu')
backend_test.exclude(r'test_maxunpool_export_with_output_shape_cpu')
......@@ -194,7 +187,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
backend_test.exclude(
r'test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_none_no_weight_cpu'
)
backend_test.exclude(r'test_qlinearconv_cpu')
backend_test.exclude(r'test_qlinearmatmul_2D_cpu')
backend_test.exclude(r'test_qlinearmatmul_3D_cpu')
backend_test.exclude(r'test_range_float_type_positive_delta_expanded_cpu')
......@@ -578,9 +570,10 @@ def disabled_tests_onnx_1_9_0(backend_test):
# fails
# from OnnxBackendNodeModelTest
backend_test.exclude(r'test_gru_batchwise_cpu')
backend_test.exclude(r'test_lstm_batchwise_cpu')
backend_test.exclude(r'test_simple_rnn_batchwise_cpu')
# from OnnxBackendPyTorchConvertedModelTest
# MaxPool dialtion is partially supported on GPU by a workaround
# But these tests require too large allocations to work properly
backend_test.exclude(r'test_MaxPool1d_stride_padding_dilation_cpu')
backend_test.exclude(r'test_MaxPool2d_stride_padding_dilation_cpu')
......@@ -638,8 +631,6 @@ def disabled_tests_onnx_1_11_0(backend_test):
# from OnnxBackendNodeModelTest
backend_test.exclude(r'test_roialign_aligned_false_cpu')
backend_test.exclude(r'test_roialign_aligned_true_cpu')
backend_test.exclude(r'test_scatternd_add_cpu')
backend_test.exclude(r'test_scatternd_multiply_cpu')
# errors
# from OnnxBackendNodeModelTest
......@@ -748,8 +739,6 @@ def disabled_tests_onnx_1_13_0(backend_test):
r'test_reduce_sum_square_negative_axes_keepdims_example_cpu')
backend_test.exclude(
r'test_reduce_sum_square_negative_axes_keepdims_random_cpu')
backend_test.exclude(r'test_scatternd_max_cpu')
backend_test.exclude(r'test_scatternd_min_cpu')
# errors
# from OnnxBackendNodeModelTest
......@@ -835,10 +824,6 @@ def disabled_tests_onnx_1_13_0(backend_test):
backend_test.exclude(r'test_scatter_elements_with_reduction_max_cpu')
backend_test.exclude(r'test_scatter_elements_with_reduction_min_cpu')
# The following tests fail due to the CastLike operator being unsupported
backend_test.exclude(r'test_split_1d_uneven_split_opset18_cpu')
backend_test.exclude(r'test_split_2d_uneven_split_opset18_cpu')
def disabled_tests_onnx_1_14_0(backend_test):
# fails
......
......@@ -22,7 +22,6 @@
# THE SOFTWARE.
#####################################################################################
import migraphx
import numpy as np
def test_conv_relu():
......@@ -51,8 +50,12 @@ def test_sub_uint64():
params = {}
shapes = p.get_parameter_shapes()
params["0"] = np.arange(120).reshape(shapes["0"].lens()).astype(np.uint64)
params["1"] = np.arange(20).reshape(shapes["1"].lens()).astype(np.uint64)
params["0"] = migraphx.create_argument(
migraphx.shape(type='uint64_type', lens=shapes["0"].lens()),
list(range(120)))
params["1"] = migraphx.create_argument(
migraphx.shape(type='uint64_type', lens=shapes["1"].lens()),
list(range(20)))
r = p.run(params)
print(r)
......@@ -67,7 +70,9 @@ def test_neg_int64():
params = {}
shapes = p.get_parameter_shapes()
params["0"] = np.arange(6).reshape(shapes["0"].lens()).astype(np.int64)
params["0"] = migraphx.create_argument(
migraphx.shape(type='int64_type', lens=shapes["0"].lens()),
list(range(6)))
r = p.run(params)
print(r)
......@@ -82,8 +87,9 @@ def test_nonzero():
params = {}
shapes = p.get_parameter_shapes()
params["data"] = np.array([1, 1, 0,
1]).reshape(shapes["data"].lens()).astype(bool)
params["data"] = migraphx.create_argument(
migraphx.shape(type='bool_type', lens=shapes["data"].lens()),
[1, 1, 0, 1])
r = p.run(params)
print(r)
......@@ -101,8 +107,8 @@ def test_fp16_imagescaler():
params = {}
shapes = p.get_parameter_shapes()
params["0"] = np.random.randn(768).reshape(shapes["0"].lens()).astype(
np.float16)
params["0"] = migraphx.generate_argument(
migraphx.shape(type='half_type', lens=shapes["0"].lens()), 768)
r = p.run(params)[-1]
print(r)
......@@ -120,10 +126,12 @@ def test_if_pl():
params = {}
shapes = p.get_parameter_shapes()
params["x"] = np.ones(6).reshape(shapes["x"].lens()).astype(np.float32)
params["y"] = np.array([2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0
]).reshape(shapes["y"].lens()).astype(np.float32)
params["cond"] = np.array([1]).reshape(()).astype(bool)
params["x"] = migraphx.fill_argument(
migraphx.shape(type='float_type', lens=shapes["x"].lens()), 1)
params["y"] = migraphx.fill_argument(
migraphx.shape(type='float_type', lens=shapes["y"].lens()), 2.0)
params["cond"] = migraphx.fill_argument(
migraphx.shape(type="bool", lens=[1], strides=[0]), 1)
r = p.run(params)[-1]
print(r)
......
......@@ -30,7 +30,7 @@
#include <migraphx/verify.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/quantize_8bits.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
......@@ -638,11 +638,10 @@ TEST_CASE(dot_float)
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto quant = mm->add_instruction(migraphx::make_op("quant_dot"), quant_a, quant_b);
std::vector<float> vec(sc.elements(), 100.0f);
auto dc = mm->add_literal(100.0f);
auto mdc =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}), dc);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, mdc);
auto scale_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), scale);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
mm->add_return({r});
return p;
......@@ -655,7 +654,8 @@ TEST_CASE(dot_float)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes(
p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}});
{migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type, quant_params},
migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
......@@ -717,24 +717,28 @@ TEST_CASE(dot_double_2args)
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto scale_a = mm->add_literal(10.0);
auto scale_a_lit = mm->add_literal(10.0);
auto zp = mm->add_literal(static_cast<int8_t>(0));
scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a);
auto scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a_lit);
auto zp_a =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a);
auto scale_b = mm->add_literal(5.0);
scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b);
auto scale_b_lit = mm->add_literal(5.0);
auto scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b_lit);
auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto scale = mm->add_literal(50.0);
scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, scale);
auto scale_a_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
scale_a_lit);
auto scale_b_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
scale_b_lit);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_a_mb, scale_b_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale);
mm->add_return({r});
return p;
};
......@@ -745,7 +749,8 @@ TEST_CASE(dot_double_2args)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes(
p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}});
{migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type, quant_params},
migraphx::dead_code_elimination{}});
EXPECT(p == create_int8_quantized_prog());
optimize_prog_int8(p);
......@@ -799,18 +804,15 @@ TEST_CASE(dot_half_1arg)
auto x = mm->add_parameter("x", sa);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(migraphx::literal({sa.type()}, {10.0}));
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}),
scale);
auto scale_lit = mm->add_literal(migraphx::literal({sa.type()}, {10.0}));
auto scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_lit);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx);
auto dq_scale = mm->add_literal(migraphx::literal({sa.type()}, {100.0}));
dq_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
dq_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, dq_scale);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale, scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale);
mm->add_return({r});
return p;
};
......@@ -821,7 +823,8 @@ TEST_CASE(dot_half_1arg)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes(
p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}});
{migraphx::quantize_8bits_pass{migraphx::shape::int8_type, quant_params},
migraphx::dead_code_elimination{}});
EXPECT(p == create_int8_quantized_prog());
optimize_prog_int8(p);
......@@ -852,9 +855,9 @@ TEST_CASE(conv_float)
auto pw = mm->add_parameter("w", sw);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(10.0f);
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}),
scale);
auto scale_lit = mm->add_literal(10.0f);
auto scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), scale_lit);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp);
auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp);
......@@ -862,13 +865,11 @@ TEST_CASE(conv_float)
auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w);
migraphx::shape sc{migraphx::shape::float_type, {4, 4, 1, 1}};
std::vector<float> vec(sc.elements(), 100.0f);
migraphx::shape s_scale{migraphx::shape::float_type, sc.lens()};
auto d_scale = mm->add_literal(100.0f);
d_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale);
auto scale_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}),
scale_lit);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
mm->add_return({r});
return p;
......@@ -878,7 +879,9 @@ TEST_CASE(conv_float)
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
std::size_t param_index = 0;
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}});
migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"convolution"}, quant_params}});
migraphx::run_passes(p,
{migraphx::quantize_8bits_pass{
migraphx::shape::type_t::int8_type, quant_params}});
optimize_prog_int8(p);
auto qp = create_int8_quantized_prog();
......@@ -903,7 +906,9 @@ TEST_CASE(conv_float_throw)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
test::throws([&] {
migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"add"}, quant_params}});
migraphx::run_passes(p,
{migraphx::quantize_8bits_pass{
migraphx::shape::type_t::int8_type, quant_params}});
});
}
......@@ -931,19 +936,20 @@ TEST_CASE(conv_half)
auto pw = mm->add_parameter("w", sw);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(migraphx::literal({sx.type()}, {10.0}));
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}),
scale);
auto scale_lit = mm->add_literal(migraphx::literal({sx.type()}, {10.0}));
auto scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), scale_lit);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp);
auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp);
auto quant_w = mm->add_instruction(migraphx::make_op("quantizelinear"), pw, scale, zp);
auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w);
auto d_scale = mm->add_literal(migraphx::literal({sx.type()}, {100.0}));
d_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale);
auto scale_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}),
scale_lit);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
mm->add_return({r});
return p;
......@@ -953,7 +959,9 @@ TEST_CASE(conv_half)
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
std::size_t param_index = 0;
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}});
migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"convolution"}, quant_params}});
migraphx::run_passes(p,
{migraphx::quantize_8bits_pass{
migraphx::shape::type_t::int8_type, quant_params}});
optimize_prog_int8(p);
auto qp = create_int8_quantized_prog();
......@@ -1187,9 +1195,9 @@ TEST_CASE(int8_subgraph)
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), zp1);
auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb);
auto qdot = then_mod->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto so = then_mod->add_literal(100.0f);
so = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so);
auto s1_mb = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), s1);
auto so = then_mod->add_instruction(migraphx::make_op("mul"), s1_mb, s1_mb);
auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so);
then_mod->add_return({r});
......@@ -1199,23 +1207,24 @@ TEST_CASE(int8_subgraph)
auto w = mm->add_parameter("w", sw);
// else submod
auto* else_mod = p.create_module("If_6_else");
auto sax = else_mod->add_literal(2.0f);
auto sax_lit = else_mod->add_literal(2.0f);
auto zp = else_mod->add_literal(static_cast<int8_t>(0));
sax = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), sax);
auto sax = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), sax_lit);
auto zpx = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), zp);
auto qx = else_mod->add_instruction(migraphx::make_op("quantizelinear"), x, sax, zpx);
auto ssw = else_mod->add_literal(1.66667f);
ssw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), ssw);
auto ssw_lit = else_mod->add_literal(1.66667f);
auto ssw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), ssw_lit);
auto zpw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), zp);
auto qw = else_mod->add_instruction(migraphx::make_op("quantizelinear"), w, ssw, zpw);
auto qconv = else_mod->add_instruction(migraphx::make_op("quant_convolution"), qx, qw);
auto so1 = else_mod->add_literal(3.33333f);
so1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so1);
auto ssw_mb = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qconv->get_shape().lens()}}),
ssw_lit);
auto so1 = else_mod->add_instruction(migraphx::make_op("mul"), sax, ssw_mb);
auto r1 = else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1);
else_mod->add_return({r1});
......@@ -1231,7 +1240,9 @@ TEST_CASE(int8_subgraph)
std::size_t param_index = 0;
migraphx::run_passes(
p1, {migraphx::capture_arguments_pass{{"convolution", "dot"}, {}, &param_index}});
migraphx::run_passes(p1, {migraphx::quantize_int8_pass{{"convolution", "dot"}, quant_params}});
migraphx::run_passes(p1,
{migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type,
quant_params}});
optimize_prog_int8(p1);
auto p2 = create_int8_program();
......
......@@ -30,7 +30,7 @@
#include <test.hpp>
TEST_CASE(allocate_dyn)
TEST_CASE(allocate_dyn0)
{
migraphx::program p;
auto* mm = p.get_main_module();
......@@ -47,3 +47,21 @@ TEST_CASE(allocate_dyn)
migraphx::shape sresult{migraphx::shape::float_type, {2, 3, 4, 4}};
result.visit([&](auto output) { EXPECT(output.get_shape() == sresult); });
}
TEST_CASE(allocate_dyn1)
{
    // allocate with a fixed "shape" attribute plus an out_dims input: after
    // compiling for the ref target, evaluation must yield the attribute shape.
    migraphx::program prog;
    auto* mm = prog.get_main_module();
    migraphx::shape dims_shape{migraphx::shape::int64_type, {4}};
    migraphx::shape expected{migraphx::shape::float_type, {2, 3, 4, 4}};
    auto dims_param = mm->add_parameter("out_dims", dims_shape);
    mm->add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(expected)}}),
                        dims_param);
    prog.compile(migraphx::make_target("ref"));

    std::vector<int64_t> dims_data = {2, 3, 4, 4};
    migraphx::parameter_map params;
    params["out_dims"] = migraphx::argument(dims_shape, dims_data.data());
    auto result = prog.eval(params).back();
    result.visit([&](auto output) { EXPECT(output.get_shape() == expected); });
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
// isinf on double input: output is 1 where the element is +/-inf, else 0
TEST_CASE(isinf_double_test)
{
    migraphx::program prog;
    auto* mm = prog.get_main_module();
    migraphx::shape s{migraphx::shape::double_type, {2, 3}};
    auto inf = std::numeric_limits<double>::infinity();
    std::vector<double> values = {1.2, 5.2, inf, -inf, 0., 100.};
    auto lit = mm->add_literal(migraphx::literal{s, values});
    mm->add_instruction(migraphx::make_op("isinf"), lit);
    prog.compile(migraphx::make_target("ref"));
    auto result = prog.eval({}).back();
    std::vector<double> out;
    result.visit([&](auto output) { out.assign(output.begin(), output.end()); });
    std::vector<double> gold = {0, 0, 1, 1, 0, 0};
    EXPECT(migraphx::verify::verify_rms_range(out, gold));
}
// isinf on float input: output is 1 where the element is +/-inf, else 0
TEST_CASE(isinf_float_test)
{
    migraphx::program prog;
    auto* mm = prog.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    auto inf = std::numeric_limits<float>::infinity();
    std::vector<float> values = {1.2, 5.2, inf, -inf, 0., 100.};
    auto lit = mm->add_literal(migraphx::literal{s, values});
    mm->add_instruction(migraphx::make_op("isinf"), lit);
    prog.compile(migraphx::make_target("ref"));
    auto result = prog.eval({}).back();
    std::vector<float> out;
    result.visit([&](auto output) { out.assign(output.begin(), output.end()); });
    std::vector<float> gold = {0, 0, 1, 1, 0, 0};
    EXPECT(migraphx::verify::verify_rms_range(out, gold));
}
// isinf on half input: output is 1 where the element is +/-inf, else 0
TEST_CASE(isinf_half_test)
{
    migraphx::program prog;
    auto* mm = prog.get_main_module();
    migraphx::shape s{migraphx::shape::half_type, {2, 3}};
    auto inf = std::numeric_limits<migraphx::half>::infinity();
    migraphx::half v1{1.2};
    migraphx::half v2{5.2};
    std::vector<migraphx::half> values = {v1, v2, inf, -inf, v2, v1};
    auto lit = mm->add_literal(migraphx::literal{s, values});
    mm->add_instruction(migraphx::make_op("isinf"), lit);
    prog.compile(migraphx::make_target("ref"));
    auto result = prog.eval({}).back();
    std::vector<float> out;
    result.visit([&](auto output) { out.assign(output.begin(), output.end()); });
    std::vector<float> gold = {0, 0, 1, 1, 0, 0};
    EXPECT(migraphx::verify::verify_rms_range(out, gold));
}
TEST_CASE(isinf_dyn_test)
{
    // isinf with a dynamic input shape: compile with dynamic dims {2,2}x{3,8},
    // then evaluate with a fixed {2, 3} float argument.
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {{2, 2}, {3, 8}}};
    auto input = mm->add_parameter("X", s);
    // Fix: was std::numeric_limits<migraphx::half>::infinity(), a copy-paste
    // leftover from the half test. The data here is float, so use float's
    // infinity directly instead of round-tripping through half.
    auto inf_val = std::numeric_limits<float>::infinity();
    mm->add_instruction(migraphx::make_op("isinf"), input);
    p.compile(migraphx::make_target("ref"));
    std::vector<float> input_data = {1.2, 5.2, inf_val, -inf_val, 0., 100.};
    migraphx::parameter_map params0;
    migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {2, 3}};
    params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
    auto result = p.eval(params0).back();
    std::vector<float> results_vector;
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {0, 0, 1, 1, 0, 0};
    EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
......@@ -24,9 +24,10 @@
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <numeric>
#include <random>
#include <test.hpp>
......@@ -48,27 +49,37 @@ TEST_CASE(multinomial_test)
migraphx::shape s{migraphx::shape::float_type, {1, 5}};
std::vector<int> dist{15, 25, 15, 25, 20};
std::vector<float> data(5);
std::transform(dist.begin(), dist.end(), data.begin(), [&](auto d) { return std::log(d); });
auto input = mm->add_literal(migraphx::literal(s, data));
std::vector<float> sum(5);
// convert to float
std::transform(dist.begin(), dist.end(), data.begin(), [&](auto d) { return d; });
// take cumulative sum
std::partial_sum(data.begin(), data.end(), sum.begin(), std::plus<float>());
// scale probabilities arbitrarily
float odd_scale = 10000.;
std::transform(sum.begin(), sum.end(), data.begin(), [&](auto d) { return d * odd_scale; });
auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input);
auto mb_maxes =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5}}}), maxes);
auto cdf = mm->add_instruction(migraphx::make_op("sub"), input, mb_maxes);
cdf = mm->add_instruction(migraphx::make_op("exp"), cdf);
cdf = mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);
auto input = mm->add_literal(migraphx::literal(s, data));
mm->add_instruction(migraphx::make_op("multinomial"), cdf, rs_lit);
mm->add_instruction(migraphx::make_op("multinomial"), input, rs_lit);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
// result_vec contains an index, or category label, for each random input value
std::vector<int32_t> result_vec(sample_size);
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
// res_dist is a count, or histogram, of the number of samples in each category. This is the
// sampled distribution.
std::vector<int> res_dist(5, 0);
for(const auto& r : result_vec)
res_dist[r]++;
// To check the result, normalize the original probability distribution dist
// and the sampling result res_dist; they should be close
// Total the unnormalized probabilities
auto dist_sum = std::accumulate(dist.begin(), dist.end(), 0);
// Total the number of values returned
auto res_dist_sum = std::accumulate(res_dist.begin(), res_dist.end(), 0);
std::vector<float> norm(5);
std::vector<float> res_norm(5);
......@@ -78,6 +89,204 @@ TEST_CASE(multinomial_test)
std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
return static_cast<double>(n) / res_dist_sum;
});
EXPECT(migraphx::verify::verify_range_with_tolerance(
res_norm, migraphx::verify::expected{norm}, migraphx::verify::tolerance{0.01}));
}
TEST_CASE(multinomial_dyn_test)
{
    // Invokes random_uniform and multinomial ops together, to verify the interface.
    // Dynamic batch dimension input of 2 means there are 2 different probability
    // distribution functions contained in Input_2.
    migraphx::program p;
    auto* mm           = p.get_main_module();
    size_t sample_size = 100000;
    size_t batch_size  = 2;

    // Shape of the random data
    migraphx::shape rs{migraphx::shape::float_type, {{1, 2}, {2, sample_size + 1}}};
    auto input = mm->add_parameter("Input_1", rs);

    // Runtime randomization seed.
    // To seed the random_uniform, we can provide a value by literal or input,
    // or ask the system to auto-seed with random_seed op.
    migraphx::shape seed_shape{migraphx::shape::uint32_type,
                               {migraphx::shape::dynamic_dimension{0, 1}}};
    auto seed_input = mm->add_parameter("Seed", seed_shape);

    // Shape of the probability distribution, which also defines the number of categories
    migraphx::shape s{migraphx::shape::float_type, {{2, 2}, {5, 6}}};

    // Unnormalized distributions for batch size 2:
    //      15, 25, 15, 25, 20
    //      20, 20, 10, 25, 25
    std::vector<int> dist{15, 25, 15, 25, 20, 20, 20, 10, 25, 25};
    // Hard-coded cumulative (accumulated) distributions follow.  The first batch row is
    // normalized to 1.0, the second is deliberately left at an arbitrary scale:
    // multinomial sampling should be invariant to the scale of the CDF.
    std::vector<float> data{.15f, .40f, .55f, .80f, 1.0f, 20.f, 40.f, 50.f, 75.f, 100.f};
    auto input2 = mm->add_parameter("Input_2", s);

    auto randoms = mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, input);
    mm->add_instruction(migraphx::make_op("multinomial"), input2, randoms);
    p.compile(migraphx::make_target("ref"));

    // Create a dummy input in the shape we want for the random data (the values are
    // ignored; random_uniform only uses the shape).
    // Bug fix: the buffer must hold batch_size * sample_size elements to match
    // input_fixed_shape1 -- migraphx::argument wraps the raw pointer without copying,
    // so an undersized vector would be read out of bounds during eval.
    std::vector<float> dummy(batch_size * sample_size, 0);
    migraphx::shape input_fixed_shape1{migraphx::shape::float_type, {batch_size, sample_size}};
    migraphx::shape input_fixed_shape2{migraphx::shape::float_type, {batch_size, 5}};
    migraphx::parameter_map params0;
    params0["Input_1"] = migraphx::argument(input_fixed_shape1, dummy.data());
    migraphx::shape seed_fixed_shape{migraphx::shape::uint32_type, {1}};
    std::vector<uint32_t> seed_data = {4};
    params0["Seed"]    = migraphx::argument(seed_fixed_shape, seed_data.data());
    params0["Input_2"] = migraphx::argument(input_fixed_shape2, data.data());
    auto result = p.eval(params0).back();
    std::vector<float> result_vec(input_fixed_shape2.elements());
    result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });

    // Make a categorical histogram of the output for the first item in the batch
    std::vector<int> res_dist(5, 0);
    size_t r = 0;
    for(r = 0; r < result_vec.size() / 2; r++)
        res_dist[result_vec[r]]++;

    // histogram for second set of batch
    std::vector<int> res_dist2(5, 0);
    for(; r < result_vec.size(); r++)
        res_dist2[result_vec[r]]++;

    // Rescale or normalize both the input probability distribution and the output
    // histogram, and compare.  Should be close but not identical.
    auto dist_sum     = std::accumulate(dist.begin(), dist.begin() + 5, 0);
    auto res_dist_sum = std::accumulate(res_dist.begin(), res_dist.end(), 0);
    std::vector<float> norm(5);
    std::vector<float> res_norm(5);
    std::transform(dist.begin(), dist.begin() + 5, norm.begin(), [&](auto n) {
        return static_cast<double>(n) / dist_sum;
    });
    std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
        return static_cast<double>(n) / res_dist_sum;
    });
    EXPECT(migraphx::verify::verify_range_with_tolerance(
        res_norm, migraphx::verify::expected{norm}, migraphx::verify::tolerance{0.01}));

    // Do the same rescaling for the 2nd in batch, which has a different probability distribution
    dist_sum     = std::accumulate(dist.begin() + 5, dist.end(), 0);
    res_dist_sum = std::accumulate(res_dist2.begin(), res_dist2.end(), 0);
    std::transform(dist.begin() + 5, dist.end(), norm.begin(), [&](auto n) {
        return static_cast<double>(n) / dist_sum;
    });
    std::transform(res_dist2.begin(), res_dist2.end(), res_norm.begin(), [&](auto n) {
        return static_cast<double>(n) / res_dist_sum;
    });
    EXPECT(migraphx::verify::verify_range_with_tolerance(
        res_norm, migraphx::verify::expected{norm}, migraphx::verify::tolerance{0.01}));
}
TEST_CASE(multinomial_float_dyn_test)
{
    // int data type for random_uniform op and float data type for multinomial.
    migraphx::program p;
    auto* mm           = p.get_main_module();
    size_t sample_size = 100000;
    size_t batch_size  = 2;

    // Shape of the random data
    migraphx::shape rs{migraphx::shape::int32_type, {{1, 2}, {2, sample_size + 1}}};
    auto input = mm->add_parameter("Input_1", rs);

    // Runtime randomization seed.
    // To seed the random_uniform, we can provide a value by literal or input,
    // or ask the system to auto-seed with random_seed op.
    migraphx::shape seed_shape{migraphx::shape::uint32_type,
                               {migraphx::shape::dynamic_dimension{0, 1}}};
    auto seed_input = mm->add_parameter("Seed", seed_shape);

    // Shape of the probability distribution, which also defines the number of categories
    migraphx::shape s{migraphx::shape::float_type, {{2, 2}, {5, 6}}};

    // Unnormalized distributions for batch size 2:
    //      15, 25, 15, 25, 20
    //      20, 20, 10, 25, 25
    std::vector<int> dist{15, 25, 15, 25, 20, 20, 20, 10, 25, 25};
    // Hard-coded normalized, accumulated distribution follows:
    std::vector<float> data{.15f, .40f, .55f, .80f, 1.0f, .20f, .40f, .50f, .75f, 1.0f};
    auto input2 = mm->add_parameter("Input_2", s);

    auto randoms = mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, input);
    mm->add_instruction(migraphx::make_op("multinomial", {{"dtype", migraphx::shape::float_type}}),
                        input2,
                        randoms);
    p.compile(migraphx::make_target("ref"));

    // Create a dummy input in the shape we want for the random data (the values are
    // ignored; random_uniform only uses the shape).
    // Bug fix: the buffer must hold batch_size * sample_size elements to match
    // input_fixed_shape1 -- migraphx::argument wraps the raw pointer without copying,
    // so an undersized vector would be read out of bounds during eval.
    // NOTE(review): input_fixed_shape1 is float_type while the "Input_1" parameter rs
    // is int32_type; both are 4-byte types, but confirm the mismatch is intentional.
    std::vector<float> dummy(batch_size * sample_size, 0);
    migraphx::shape input_fixed_shape1{migraphx::shape::float_type, {batch_size, sample_size}};
    migraphx::shape input_fixed_shape2{migraphx::shape::float_type, {batch_size, 5}};
    migraphx::parameter_map params0;
    params0["Input_1"] = migraphx::argument(input_fixed_shape1, dummy.data());
    migraphx::shape seed_fixed_shape{migraphx::shape::uint32_type, {1}};
    std::vector<uint32_t> seed_data = {4};
    params0["Seed"]    = migraphx::argument(seed_fixed_shape, seed_data.data());
    params0["Input_2"] = migraphx::argument(input_fixed_shape2, data.data());
    auto result = p.eval(params0).back();
    std::vector<float> result_vec(input_fixed_shape2.elements());
    result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });

    // Make a categorical histogram of the output for the first item in the batch
    std::vector<int> res_dist(5, 0);
    size_t r = 0;
    for(r = 0; r < result_vec.size() / 2; r++)
        res_dist[result_vec[r]]++;

    // histogram for second set of batch
    std::vector<int> res_dist2(5, 0);
    for(; r < result_vec.size(); r++)
        res_dist2[result_vec[r]]++;

    // Rescale or normalize both the input probability distribution and the output
    // histogram, and compare.  Should be close but not identical.
    auto dist_sum     = std::accumulate(dist.begin(), dist.begin() + 5, 0);
    auto res_dist_sum = std::accumulate(res_dist.begin(), res_dist.end(), 0);
    std::vector<float> norm(5);
    std::vector<float> res_norm(5);
    std::transform(dist.begin(), dist.begin() + 5, norm.begin(), [&](auto n) {
        return static_cast<double>(n) / dist_sum;
    });
    std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
        return static_cast<double>(n) / res_dist_sum;
    });
    EXPECT(migraphx::verify::verify_range_with_tolerance(
        res_norm, migraphx::verify::expected{norm}, migraphx::verify::tolerance{0.01}));

    // Do the same rescaling for the 2nd in batch, which has a different probability distribution
    dist_sum     = std::accumulate(dist.begin() + 5, dist.end(), 0);
    res_dist_sum = std::accumulate(res_dist2.begin(), res_dist2.end(), 0);
    std::transform(dist.begin() + 5, dist.end(), norm.begin(), [&](auto n) {
        return static_cast<double>(n) / dist_sum;
    });
    std::transform(res_dist2.begin(), res_dist2.end(), res_norm.begin(), [&](auto n) {
        return static_cast<double>(n) / res_dist_sum;
    });
    EXPECT(migraphx::verify::verify_range_with_tolerance(
        res_norm, migraphx::verify::expected{norm}, migraphx::verify::tolerance{0.01}));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment