Unverified Commit 0b2bcf2c authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'develop' into add_parity_check_ci

parents fbacefca dcc7b0a5
...@@ -43,7 +43,7 @@ TEST_CASE(sign_test) ...@@ -43,7 +43,7 @@ TEST_CASE(sign_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 1.0, -1.0, -1.0, 0.0}; std::vector<float> gold = {1.0, 1.0, -1.0, -1.0, 0.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(sign_dyn_test) TEST_CASE(sign_dyn_test)
...@@ -64,5 +64,5 @@ TEST_CASE(sign_dyn_test) ...@@ -64,5 +64,5 @@ TEST_CASE(sign_dyn_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 1.0, -1.0, -1.0, 0.0}; std::vector<float> gold = {1.0, 1.0, -1.0, -1.0, 0.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(sin_test) ...@@ -45,7 +45,7 @@ TEST_CASE(sin_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(sin_dyn_test) TEST_CASE(sin_dyn_test)
...@@ -68,5 +68,5 @@ TEST_CASE(sin_dyn_test) ...@@ -68,5 +68,5 @@ TEST_CASE(sin_dyn_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(sinh_test) ...@@ -45,7 +45,7 @@ TEST_CASE(sinh_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinhf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinhf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(sinh_dynamic_test) TEST_CASE(sinh_dynamic_test)
...@@ -67,5 +67,5 @@ TEST_CASE(sinh_dynamic_test) ...@@ -67,5 +67,5 @@ TEST_CASE(sinh_dynamic_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinhf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinhf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -48,7 +48,7 @@ TEST_CASE(slice_test_1) ...@@ -48,7 +48,7 @@ TEST_CASE(slice_test_1)
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11}; std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2); std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
EXPECT(result.get_shape() == sresult); EXPECT(result.get_shape() == sresult);
} }
...@@ -72,7 +72,7 @@ TEST_CASE(slice_test_2) ...@@ -72,7 +72,7 @@ TEST_CASE(slice_test_2)
std::vector<int> gold = {0, 1, 3, 4, 6, 7, 9, 10}; std::vector<int> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<int> results_vector(2 * 2 * 2); std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
EXPECT(result.get_shape() == sresult); EXPECT(result.get_shape() == sresult);
} }
...@@ -99,7 +99,7 @@ TEST_CASE(slice_var_inputs_static0) ...@@ -99,7 +99,7 @@ TEST_CASE(slice_var_inputs_static0)
std::vector<int32_t> gold = {1, 2, 4, 5, 7, 8, 10, 11}; std::vector<int32_t> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int32_t> results_vector(2 * 2 * 2); std::vector<int32_t> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(slice_var_inputs_static1) TEST_CASE(slice_var_inputs_static1)
...@@ -125,7 +125,7 @@ TEST_CASE(slice_var_inputs_static1) ...@@ -125,7 +125,7 @@ TEST_CASE(slice_var_inputs_static1)
std::vector<int32_t> gold = {1, 2, 4, 5, 7, 8, 10, 11}; std::vector<int32_t> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int32_t> results_vector(2 * 2 * 2); std::vector<int32_t> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(slice_var_inputs_static2) TEST_CASE(slice_var_inputs_static2)
...@@ -154,7 +154,7 @@ TEST_CASE(slice_var_inputs_static2) ...@@ -154,7 +154,7 @@ TEST_CASE(slice_var_inputs_static2)
std::vector<float> gold = {0, 1, 3, 4, 6, 7, 9, 10}; std::vector<float> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<float> results_vector(2 * 2 * 2); std::vector<float> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(slice_var_inputs_dyn) TEST_CASE(slice_var_inputs_dyn)
...@@ -182,7 +182,7 @@ TEST_CASE(slice_var_inputs_dyn) ...@@ -182,7 +182,7 @@ TEST_CASE(slice_var_inputs_dyn)
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11}; std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2); std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(slice_dyn_test0) TEST_CASE(slice_dyn_test0)
...@@ -213,7 +213,7 @@ TEST_CASE(slice_dyn_test0) ...@@ -213,7 +213,7 @@ TEST_CASE(slice_dyn_test0)
std::vector<int> results_vector(2 * 1 * 2); std::vector<int> results_vector(2 * 1 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
EXPECT(result.get_shape() == sresult); EXPECT(result.get_shape() == sresult);
} }
...@@ -244,6 +244,6 @@ TEST_CASE(slice_dyn_test1) ...@@ -244,6 +244,6 @@ TEST_CASE(slice_dyn_test1)
std::vector<int> gold = {0, 1, 3, 4, 6, 7, 9, 10}; std::vector<int> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<int> results_vector(2 * 2 * 2); std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
EXPECT(result.get_shape() == sresult); EXPECT(result.get_shape() == sresult);
} }
...@@ -33,9 +33,9 @@ ...@@ -33,9 +33,9 @@
TEST_CASE(softmax_simple_test) TEST_CASE(softmax_simple_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<float> a = {0.25, 0.75}; std::vector<float> a = {0.25, 0.75};
std::vector<float> s = {0.377541, 0.622459}; std::vector<float> gold = {0.377541, 0.622459};
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2}}; migraphx::shape a_shape{migraphx::shape::float_type, {1, 2}};
auto al = mm->add_literal(migraphx::literal{a_shape, a}); auto al = mm->add_literal(migraphx::literal{a_shape, a});
mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), al); mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), al);
...@@ -43,7 +43,7 @@ TEST_CASE(softmax_simple_test) ...@@ -43,7 +43,7 @@ TEST_CASE(softmax_simple_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(2); std::vector<float> results_vector(2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(softmax_test) TEST_CASE(softmax_test)
...@@ -76,7 +76,7 @@ TEST_CASE(softmax_test) ...@@ -76,7 +76,7 @@ TEST_CASE(softmax_test)
2.02392387e+00, -9.42091495e-02, -3.77683818e-01, 2.05638766e+00, 2.93796062e-01, 2.02392387e+00, -9.42091495e-02, -3.77683818e-01, 2.05638766e+00, 2.93796062e-01,
-6.02131486e-01, 2.70461679e-01, -8.92358482e-01, 1.04388881e+00, 2.66154885e-01}; -6.02131486e-01, 2.70461679e-01, -8.92358482e-01, 1.04388881e+00, 2.66154885e-01};
std::vector<float> s = { std::vector<float> gold = {
0.30191708, 0.59879845, 0.50029165, 0.24915339, 0.36823985, 0.13190967, 0.0349741, 0.30191708, 0.59879845, 0.50029165, 0.24915339, 0.36823985, 0.13190967, 0.0349741,
0.18750034, 0.21905553, 0.27000085, 0.0547399, 0.56318235, 0.47422904, 0.78964758, 0.18750034, 0.21905553, 0.27000085, 0.0547399, 0.56318235, 0.47422904, 0.78964758,
0.91381913, 0.44601166, 0.47902739, 0.13120073, 0.4449684, 0.18766427, 0.15753111, 0.91381913, 0.44601166, 0.47902739, 0.13120073, 0.4449684, 0.18766427, 0.15753111,
...@@ -103,7 +103,7 @@ TEST_CASE(softmax_test) ...@@ -103,7 +103,7 @@ TEST_CASE(softmax_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(120); std::vector<float> results_vector(120);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(softmax_dyn_test) TEST_CASE(softmax_dyn_test)
...@@ -147,7 +147,7 @@ TEST_CASE(softmax_dyn_test) ...@@ -147,7 +147,7 @@ TEST_CASE(softmax_dyn_test)
auto result = p.eval(params).back(); auto result = p.eval(params).back();
std::vector<float> results_vector(120); std::vector<float> results_vector(120);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> s = { std::vector<float> gold = {
0.30191708, 0.59879845, 0.50029165, 0.24915339, 0.36823985, 0.13190967, 0.0349741, 0.30191708, 0.59879845, 0.50029165, 0.24915339, 0.36823985, 0.13190967, 0.0349741,
0.18750034, 0.21905553, 0.27000085, 0.0547399, 0.56318235, 0.47422904, 0.78964758, 0.18750034, 0.21905553, 0.27000085, 0.0547399, 0.56318235, 0.47422904, 0.78964758,
0.91381913, 0.44601166, 0.47902739, 0.13120073, 0.4449684, 0.18766427, 0.15753111, 0.91381913, 0.44601166, 0.47902739, 0.13120073, 0.4449684, 0.18766427, 0.15753111,
...@@ -166,5 +166,5 @@ TEST_CASE(softmax_dyn_test) ...@@ -166,5 +166,5 @@ TEST_CASE(softmax_dyn_test)
0.13268511, 0.61795473, 0.49703068, 0.41696799, 0.10175809, 0.71028161, 0.29929739, 0.13268511, 0.61795473, 0.49703068, 0.41696799, 0.10175809, 0.71028161, 0.29929739,
0.17377149, 0.76075399, 0.20071237, 0.32632929, 0.36892858, 0.09416146, 0.26656723, 0.17377149, 0.76075399, 0.20071237, 0.32632929, 0.36892858, 0.09416146, 0.26656723,
0.42914796}; 0.42914796};
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -43,7 +43,7 @@ TEST_CASE(sqdiff_test) ...@@ -43,7 +43,7 @@ TEST_CASE(sqdiff_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {4, 4, 4}; std::vector<float> gold = {4, 4, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(sqdiff_dyn_test) TEST_CASE(sqdiff_dyn_test)
...@@ -67,5 +67,5 @@ TEST_CASE(sqdiff_dyn_test) ...@@ -67,5 +67,5 @@ TEST_CASE(sqdiff_dyn_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {4, 4, 4}; std::vector<float> gold = {4, 4, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(sqrt_test) ...@@ -45,7 +45,7 @@ TEST_CASE(sqrt_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sqrtf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sqrtf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(sqrt_dynamic_test) TEST_CASE(sqrt_dynamic_test)
...@@ -68,5 +68,5 @@ TEST_CASE(sqrt_dynamic_test) ...@@ -68,5 +68,5 @@ TEST_CASE(sqrt_dynamic_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sqrtf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sqrtf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -43,7 +43,7 @@ TEST_CASE(sub_test) ...@@ -43,7 +43,7 @@ TEST_CASE(sub_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2, -2, -2}; std::vector<float> gold = {-2, -2, -2};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(sub_dyn_test) TEST_CASE(sub_dyn_test)
...@@ -67,5 +67,5 @@ TEST_CASE(sub_dyn_test) ...@@ -67,5 +67,5 @@ TEST_CASE(sub_dyn_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2, -2, -2}; std::vector<float> gold = {-2, -2, -2};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(tan_test) ...@@ -45,7 +45,7 @@ TEST_CASE(tan_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(tan_dynamic_test) TEST_CASE(tan_dynamic_test)
...@@ -68,5 +68,5 @@ TEST_CASE(tan_dynamic_test) ...@@ -68,5 +68,5 @@ TEST_CASE(tan_dynamic_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(tanh_test) ...@@ -45,7 +45,7 @@ TEST_CASE(tanh_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanhf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanhf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(tanh_dynamic_test) TEST_CASE(tanh_dynamic_test)
...@@ -68,5 +68,5 @@ TEST_CASE(tanh_dynamic_test) ...@@ -68,5 +68,5 @@ TEST_CASE(tanh_dynamic_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanhf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanhf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -59,7 +59,7 @@ TEST_CASE(transpose_test) ...@@ -59,7 +59,7 @@ TEST_CASE(transpose_test)
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result2.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result2.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11}; std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
} }
...@@ -86,5 +86,5 @@ TEST_CASE(transpose_dyn_test) ...@@ -86,5 +86,5 @@ TEST_CASE(transpose_dyn_test)
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11}; std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -54,7 +54,7 @@ TEST_CASE(where_test) ...@@ -54,7 +54,7 @@ TEST_CASE(where_test)
for(int i = 0; i < gold.size(); ++i) for(int i = 0; i < gold.size(); ++i)
gold[i] = b[i] ? x[i] : y[i]; gold[i] = b[i] ? x[i] : y[i];
EXPECT(migraphx::verify::verify_range(result_vec, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vec, gold));
} }
TEST_CASE(where_dyn_test) TEST_CASE(where_dyn_test)
...@@ -85,7 +85,7 @@ TEST_CASE(where_dyn_test) ...@@ -85,7 +85,7 @@ TEST_CASE(where_dyn_test)
std::vector<float> results_vector(3 * 3); std::vector<float> results_vector(3 * 3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 1, 1, 2, 2, 2, 1, 2, 1}; std::vector<float> gold{1, 1, 1, 2, 2, 2, 1, 2, 1};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(where_broadcasted_inputs_test) TEST_CASE(where_broadcasted_inputs_test)
...@@ -113,5 +113,5 @@ TEST_CASE(where_broadcasted_inputs_test) ...@@ -113,5 +113,5 @@ TEST_CASE(where_broadcasted_inputs_test)
for(int i = 0; i < gold.size(); ++i) for(int i = 0; i < gold.size(); ++i)
gold[i] = b[i] ? x[i] : y[i]; gold[i] = b[i] ? x[i] : y[i];
EXPECT(migraphx::verify::verify_range(result_vec, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vec, gold));
} }
...@@ -140,24 +140,6 @@ TEST_CASE(handling_tensors) ...@@ -140,24 +140,6 @@ TEST_CASE(handling_tensors)
-0.06269585, 0.18658121, -0.03944227, 0.0111798, -0.17731084, 0.11789055, -0.09982193, -0.06269585, 0.18658121, -0.03944227, 0.0111798, -0.17731084, 0.11789055, -0.09982193,
0.08142821, 0.0729029, 0.11303909, 0.12735154, 0.03885292}; 0.08142821, 0.0729029, 0.11303909, 0.12735154, 0.03885292};
// Solution vector
std::vector<float> sol = {-0.20817225,
0.87965256,
0.14958936,
-1.24887264,
-0.06540672,
0.20778663,
0.40456355,
-0.99900877,
0.4917807,
0.1994698,
0.64205718,
0.37798831,
-0.25315839,
0.44276932,
-0.16138598,
0.79344082};
// Create the arguments in a parameter_map // Create the arguments in a parameter_map
migraphx::parameter_map params; migraphx::parameter_map params;
params["X"] = migraphx::argument(input_shape, a.data()); params["X"] = migraphx::argument(input_shape, a.data());
...@@ -167,8 +149,25 @@ TEST_CASE(handling_tensors) ...@@ -167,8 +149,25 @@ TEST_CASE(handling_tensors)
auto result = p.eval(params).back(); auto result = p.eval(params).back();
std::vector<float> results_vector(64); std::vector<float> results_vector(64);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
// Solution vector
EXPECT(migraphx::verify::verify_range(results_vector, sol)); std::vector<float> gold = {-0.20817225,
0.87965256,
0.14958936,
-1.24887264,
-0.06540672,
0.20778663,
0.40456355,
-0.99900877,
0.4917807,
0.1994698,
0.64205718,
0.37798831,
-0.25315839,
0.44276932,
-0.16138598,
0.79344082};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -50,10 +50,10 @@ TEST_CASE(rewrite_pooling_test) ...@@ -50,10 +50,10 @@ TEST_CASE(rewrite_pooling_test)
migraphx::module m; migraphx::module m;
auto input = m.add_parameter("x", s); auto input = m.add_parameter("x", s);
auto ret = m.add_instruction(migraphx::make_op("pooling", auto ret = m.add_instruction(migraphx::make_op("pooling",
{{"mode", mode}, {{"mode", mode},
{"padding", {0, 0, 0}}, {"padding", {0, 0, 0}},
{"stride", {1, 1, 1}}, {"stride", {1, 1, 1}},
{"lengths", {3, 4, 5}}}), {"lengths", {3, 4, 5}}}),
input); input);
m.add_return({ret}); m.add_return({ret});
return m; return m;
...@@ -62,11 +62,8 @@ TEST_CASE(rewrite_pooling_test) ...@@ -62,11 +62,8 @@ TEST_CASE(rewrite_pooling_test)
auto opt_program = [&](const migraphx::operation& reduce_op) { auto opt_program = [&](const migraphx::operation& reduce_op) {
migraphx::module m; migraphx::module m;
auto input = m.add_parameter("x", s); auto input = m.add_parameter("x", s);
auto rsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), input); auto rdm = m.add_instruction(reduce_op, input);
auto rdm = m.add_instruction(reduce_op, rsp); m.add_return({rdm});
auto ret =
m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 1, 1, 1}}}), rdm);
m.add_return({ret});
return m; return m;
}; };
...@@ -78,8 +75,9 @@ TEST_CASE(rewrite_pooling_test) ...@@ -78,8 +75,9 @@ TEST_CASE(rewrite_pooling_test)
}; };
test_rewrite(migraphx::op::pooling_mode::average, test_rewrite(migraphx::op::pooling_mode::average,
migraphx::make_op("reduce_mean", {{"axes", {1}}})); migraphx::make_op("reduce_mean", {{"axes", {2, 3, 4}}}));
test_rewrite(migraphx::op::pooling_mode::max, migraphx::make_op("reduce_max", {{"axes", {1}}})); test_rewrite(migraphx::op::pooling_mode::max,
migraphx::make_op("reduce_max", {{"axes", {2, 3, 4}}}));
} }
TEST_CASE(rewrite_avepooling_na1_test) TEST_CASE(rewrite_avepooling_na1_test)
...@@ -140,10 +138,10 @@ TEST_CASE(rewrite_avepooling_na3_test) ...@@ -140,10 +138,10 @@ TEST_CASE(rewrite_avepooling_na3_test)
auto input = m.add_parameter("x", s); auto input = m.add_parameter("x", s);
auto ret = m.add_instruction(migraphx::make_op("pooling", auto ret = m.add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max}, {{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0, 0}}, {"padding", {0, 0, 0}},
{"stride", {1, 1, 1}}, {"stride", {1, 1, 1}},
{"lengths", {3, 3, 5}}}), {"lengths", {3, 3, 5}}}),
input); input);
m.add_return({ret}); m.add_return({ret});
return m; return m;
...@@ -168,10 +166,10 @@ TEST_CASE(literal_rewrite_pooling_test) ...@@ -168,10 +166,10 @@ TEST_CASE(literal_rewrite_pooling_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = mm->add_literal(migraphx::literal(s, data)); auto input = mm->add_literal(migraphx::literal(s, data));
auto ret = mm->add_instruction(migraphx::make_op("pooling", auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", mode}, {{"mode", mode},
{"padding", {0, 0, 0}}, {"padding", {0, 0, 0}},
{"stride", {1, 1, 1}}, {"stride", {1, 1, 1}},
{"lengths", {3, 4, 5}}}), {"lengths", {3, 4, 5}}}),
input); input);
mm->add_return({ret}); mm->add_return({ret});
return p; return p;
...@@ -199,7 +197,7 @@ TEST_CASE(literal_rewrite_pooling_test) ...@@ -199,7 +197,7 @@ TEST_CASE(literal_rewrite_pooling_test)
auto result1 = p1.eval({}).back(); auto result1 = p1.eval({}).back();
auto result2 = p2.eval({}).back(); auto result2 = p2.eval({}).back();
visit_all(result1, result2)( visit_all(result1, result2)(
[&](auto r1, auto r2) { EXPECT(migraphx::verify::verify_range(r1, r2)); }); [&](auto r1, auto r2) { EXPECT(migraphx::verify::verify_rms_range(r1, r2)); });
}; };
test_rewrite_pooling(migraphx::op::pooling_mode::max, test_rewrite_pooling(migraphx::op::pooling_mode::max,
......
...@@ -68,7 +68,7 @@ TEST_CASE(eval_run_on_target) ...@@ -68,7 +68,7 @@ TEST_CASE(eval_run_on_target)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.5, 0.25, 0.125}; std::vector<float> gold = {0.5, 0.25, 0.125};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -956,13 +956,13 @@ TEST_CASE(test_with_type) ...@@ -956,13 +956,13 @@ TEST_CASE(test_with_type)
TEST_CASE(test_multi_index) TEST_CASE(test_multi_index)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 4, 6}}; migraphx::shape s{migraphx::shape::float_type, {2, 4, 6}};
EXPECT(migraphx::verify::verify_range(s.multi(0), std::vector<size_t>{0, 0, 0})); EXPECT(migraphx::verify::verify_rms_range(s.multi(0), std::vector<size_t>{0, 0, 0}));
EXPECT(migraphx::verify::verify_range(s.multi(4), std::vector<size_t>{0, 0, 4})); EXPECT(migraphx::verify::verify_rms_range(s.multi(4), std::vector<size_t>{0, 0, 4}));
EXPECT(migraphx::verify::verify_range(s.multi(6), std::vector<size_t>{0, 1, 0})); EXPECT(migraphx::verify::verify_rms_range(s.multi(6), std::vector<size_t>{0, 1, 0}));
EXPECT(migraphx::verify::verify_range(s.multi(8), std::vector<size_t>{0, 1, 2})); EXPECT(migraphx::verify::verify_rms_range(s.multi(8), std::vector<size_t>{0, 1, 2}));
EXPECT(migraphx::verify::verify_range(s.multi(24), std::vector<size_t>{1, 0, 0})); EXPECT(migraphx::verify::verify_rms_range(s.multi(24), std::vector<size_t>{1, 0, 0}));
EXPECT(migraphx::verify::verify_range(s.multi(30), std::vector<size_t>{1, 1, 0})); EXPECT(migraphx::verify::verify_rms_range(s.multi(30), std::vector<size_t>{1, 1, 0}));
EXPECT(migraphx::verify::verify_range(s.multi(34), std::vector<size_t>{1, 1, 4})); EXPECT(migraphx::verify::verify_rms_range(s.multi(34), std::vector<size_t>{1, 1, 4}));
} }
TEST_CASE(find_permutation_2d_standard) TEST_CASE(find_permutation_2d_standard)
......
...@@ -2910,6 +2910,179 @@ TEST_CASE(reorder_reshape_slice_not_apply) ...@@ -2910,6 +2910,179 @@ TEST_CASE(reorder_reshape_slice_not_apply)
EXPECT(m1.sort() == m2.sort()); EXPECT(m1.sort() == m2.sort());
} }
TEST_CASE(reorder_reshape_slice_multi_rsp)
{
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {4, 128, 3, 32, 80}};
auto input = m1.add_parameter("input", s);
auto t1 = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), input);
auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), t1);
auto slc1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), t1);
auto slc2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), t1);
auto c1_1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2_1 = m1.add_instruction(migraphx::make_op("contiguous"), slc2);
auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
auto r1 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 32, 128, 80}}}), c1);
auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2);
auto r2 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 32, 128, 80}}}), c2);
auto r1_1 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {128, 128, 80}}}), c1_1);
auto r2_1 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {128, 128, 80}}}), c2_1);
auto c0 = m1.add_instruction(migraphx::make_op("contiguous"), slc0);
auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {128, 128, 80}}}), c0);
auto t2 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), r1_1);
auto c_t2 = m1.add_instruction(migraphx::make_op("contiguous"), t2);
auto dot = m1.add_instruction(migraphx::make_op("dot"), r0, c_t2);
m1.add_return({r1, r2, dot, r2_1});
};
migraphx::module m2;
{
migraphx::shape s{migraphx::shape::float_type, {4, 128, 3, 32, 80}};
auto input = m2.add_parameter("input", s);
auto t1 = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), input);
auto c_t1 = m2.add_instruction(migraphx::make_op("contiguous"), t1);
auto rsp1 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {384, 128, 80}}}), c_t1);
auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {256}}, {"ends", {384}}}), rsp1);
auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {128}}, {"ends", {256}}}), rsp1);
auto t_slc1 =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), slc1);
auto c_t_slc1 = m2.add_instruction(migraphx::make_op("contiguous"), t_slc1);
auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {128}}}), rsp1);
auto dot = m2.add_instruction(migraphx::make_op("dot"), slc2, c_t_slc1);
auto c_t1_1 = m2.add_instruction(migraphx::make_op("contiguous"), t1);
auto rsp2 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {12, 32, 128, 80}}}), c_t1_1);
auto slc2_1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {8}}}), rsp2);
auto slc2_2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {8}}, {"ends", {12}}}), rsp2);
m2.add_return({slc2_1, slc2_2, dot, slc0});
};
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_reshape_slice_partial)
{
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {128, 96}};
auto input = m1.add_parameter("input", s);
auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), input);
auto slc1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {8}}, {"ends", {16}}}), input);
auto slc2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {16}}, {"ends", {24}}}), input);
auto slc3 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {24}}, {"ends", {128}}}), input);
auto c0 = m1.add_instruction(migraphx::make_op("contiguous"), slc0);
auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2);
std::vector<int64_t> lens = {2, 4, 96};
auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
auto sum = m1.add_instruction(migraphx::make_op("add"), r0, r1);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, r2);
m1.add_return({ret, slc3});
};
migraphx::module m2;
{
migraphx::shape s{migraphx::shape::float_type, {128, 96}};
auto input = m2.add_parameter("input", s);
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {32, 4, 96}}}), input);
auto slc3 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {24}}, {"ends", {128}}}), input);
auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), rsp);
auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), rsp);
auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {6}}}), rsp);
auto sum = m2.add_instruction(migraphx::make_op("add"), slc0, slc1);
auto ret = m2.add_instruction(migraphx::make_op("mul"), sum, slc2);
m2.add_return({ret, slc3});
};
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_reshape_slice_uneven_slice)
{
auto create_p = [] {
migraphx::module m;
migraphx::shape s{migraphx::shape::float_type, {128, 96}};
auto input = m.add_parameter("input", s);
auto slc0 = m.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {31}}}), input);
auto slc1 = m.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {31}}, {"ends", {62}}}), input);
auto slc2 = m.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {62}}, {"ends", {93}}}), input);
auto slc3 = m.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {93}}, {"ends", {128}}}), input);
auto c0 = m.add_instruction(migraphx::make_op("contiguous"), slc0);
auto c1 = m.add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = m.add_instruction(migraphx::make_op("contiguous"), slc2);
std::vector<int64_t> lens = {1, 31, 96};
auto r0 = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
auto sum = m.add_instruction(migraphx::make_op("add"), r0, r1);
auto ret = m.add_instruction(migraphx::make_op("mul"), sum, r2);
m.add_return({ret, slc3});
return m;
};
auto m1 = create_p();
auto m2 = m1;
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
template <std::size_t BS> template <std::size_t BS>
void reorder_reshape_slice_diff_dims() void reorder_reshape_slice_diff_dims()
{ {
...@@ -2931,13 +3104,32 @@ void reorder_reshape_slice_diff_dims() ...@@ -2931,13 +3104,32 @@ void reorder_reshape_slice_diff_dims()
std::vector<int64_t> lens = {static_cast<int64_t>(BS), 32, 3, 32}; std::vector<int64_t> lens = {static_cast<int64_t>(BS), 32, 3, 32};
std::vector<int64_t> lens1 = {static_cast<int64_t>(BS), 48, 2, 32}; std::vector<int64_t> lens1 = {static_cast<int64_t>(BS), 48, 2, 32};
auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0); auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1); auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens1}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens1}}), c2); auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
m1.add_return({r0, r1, r2}); m1.add_return({r0, r1, r2});
}; };
auto m2 = m1; migraphx::module m2;
{
auto s = migraphx::shape{migraphx::shape::float_type, {BS, 96, 96}};
auto input = m2.add_parameter("input", s);
auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), input);
auto c1 = m2.add_instruction(migraphx::make_op("contiguous"), slc1);
std::vector<int64_t> lens1 = {static_cast<int64_t>(BS), 48, 2, 32};
auto r1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens1}}), c1);
std::vector<int64_t> lens = {static_cast<int64_t>(BS), 32, 3, 96};
auto r_new = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), r_new);
auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), r_new);
m2.add_return({slc0, r1, slc2});
};
run_pass(m1); run_pass(m1);
EXPECT(m1.sort() == m2.sort()); EXPECT(m1.sort() == m2.sort());
} }
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_pass(migraphx::module& m)
{
migraphx::run_passes(m, {migraphx::simplify_dyn_ops{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(static_broadcast)
{
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {2, 4}};
auto input = m0.add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {4}}};
auto literal_ins = m0.add_literal(migraphx::literal{lit_s, {6, 5, 4, 3}});
auto broadcast_lit = m0.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), literal_ins);
auto add_ins = m0.add_instruction(migraphx::make_op("add"), input, broadcast_lit);
m0.add_return({add_ins});
}
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {2, 4}};
auto input = m1.add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {4}}};
auto literal_ins = m1.add_literal(migraphx::literal{lit_s, {6, 5, 4, 3}});
auto broadcast_lit =
m1.add_instruction(migraphx::make_op("broadcast", {{"axis", 1}}), literal_ins, input);
auto add_ins = m1.add_instruction(migraphx::make_op("add"), input, broadcast_lit);
m1.add_return({add_ins});
}
run_pass(m1);
EXPECT(m0 == m1);
}
TEST_CASE(static_multibroadcast)
{
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {2, 4}};
auto input = m0.add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}, {0}}};
auto literal_ins = m0.add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit = m0.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), literal_ins);
auto add_ins = m0.add_instruction(migraphx::make_op("add"), input, broadcast_lit);
m0.add_return({add_ins});
}
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {2, 4}};
auto input = m1.add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}, {0}}};
auto literal_ins = m1.add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
m1.add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input);
auto add_ins = m1.add_instruction(migraphx::make_op("add"), input, broadcast_lit);
m1.add_return({add_ins});
}
run_pass(m1);
EXPECT(m0 == m1);
}
TEST_CASE(after_split_dyn_broadcast_match)
{
migraphx::program p0;
{
auto* mm0 = p0.get_main_module();
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p0.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {4}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6, 5, 4, 3}});
auto broadcast_lit = submod->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", sm_shape.lens()}}),
literal_ins);
auto add_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
submod->add_return({add_ins});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
auto* dim3 = create_submodule(3, "dim_3");
auto* dim4 = create_submodule(4, "dim_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{dim1, dim2, dim3, dim4});
auto ret =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm0->add_return({ret});
}
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input1 = mm1->add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {4}}};
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6, 5, 4, 3}});
auto broadcast_lit = mm1->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}}), literal_ins, input1);
auto add_ins = mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
mm1->add_return({add_ins});
}
migraphx::run_passes(p1,
{migraphx::split_single_dyn_dim{},
migraphx::dead_code_elimination{},
migraphx::simplify_dyn_ops{}});
EXPECT(p0 == p1);
}
TEST_CASE(const_slice_3input)
{
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m0.add_parameter("data", s);
auto slice_ins = m0.add_instruction(
migraphx::make_op("slice", {{"starts", {0}}, {"ends", {3}}, {"axes", {0}}}), input);
m0.add_return({slice_ins});
}
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m1.add_parameter("data", s);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto input_starts = m1.add_literal(migraphx::literal{s1, {0}});
auto input_ends = m1.add_literal(migraphx::literal{s1, {3}});
auto slice_ins = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}}), input, input_starts, input_ends);
m1.add_return({slice_ins});
}
run_pass(m1);
EXPECT(m0 == m1);
}
TEST_CASE(const_slice_3input_dyn)
{
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {{6, 6}, {2, 4, {2, 4}}, {2, 4, {2, 4}}}};
auto input = m0.add_parameter("data", s);
auto slice_ins = m0.add_instruction(
migraphx::make_op("slice", {{"starts", {0}}, {"ends", {3}}, {"axes", {0}}}), input);
m0.add_return({slice_ins});
}
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {{6, 6}, {2, 4, {2, 4}}, {2, 4, {2, 4}}}};
auto input = m1.add_parameter("data", s);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto input_starts = m1.add_literal(migraphx::literal{s1, {0}});
auto input_ends = m1.add_literal(migraphx::literal{s1, {3}});
auto slice_ins = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}}), input, input_starts, input_ends);
m1.add_return({slice_ins});
}
run_pass(m1);
EXPECT(m0 == m1);
}
TEST_CASE(const_slice_4input)
{
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m0.add_parameter("data", s);
auto slice_ins = m0.add_instruction(
migraphx::make_op("slice", {{"starts", {0}}, {"ends", {3}}, {"axes", {0}}}), input);
m0.add_return({slice_ins});
}
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m1.add_parameter("data", s);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto input_starts = m1.add_literal(migraphx::literal{s1, {0}});
auto input_ends = m1.add_literal(migraphx::literal{s1, {3}});
auto input_axes = m1.add_literal(migraphx::literal{s1, {0}});
auto slice_ins = m1.add_instruction(
migraphx::make_op("slice"), input, input_starts, input_ends, input_axes);
m1.add_return({slice_ins});
}
run_pass(m1);
EXPECT(m0 == m1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -479,11 +479,11 @@ TEST_CASE(conv_pooling_dot) ...@@ -479,11 +479,11 @@ TEST_CASE(conv_pooling_dot)
auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero); auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero);
auto d5 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); auto d5 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto c1 = m1.add_instruction(migraphx::make_op("convolution", auto c1 = m1.add_instruction(migraphx::make_op("convolution",
{{"padding", {0, 0, 0, 0}}, {{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}}, {"stride", {1, 1}},
{"dilation", {1, 1}}, {"dilation", {1, 1}},
{"group", 1}, {"group", 1},
{"padding_mode", 0}}), {"padding_mode", 0}}),
d5, d5,
d1); d1);
auto bc1 = m1.add_instruction( auto bc1 = m1.add_instruction(
...@@ -526,11 +526,11 @@ TEST_CASE(conv_pooling_dot) ...@@ -526,11 +526,11 @@ TEST_CASE(conv_pooling_dot)
auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale, zero); auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale, zero);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero); auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution", auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
{{"padding", {0, 0, 0, 0}}, {{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}}, {"stride", {1, 1}},
{"dilation", {1, 1}}, {"dilation", {1, 1}},
{"group", 1}, {"group", 1},
{"padding_mode", 0}}), {"padding_mode", 0}}),
q1, q1,
weights); weights);
auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1); auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
...@@ -585,11 +585,11 @@ TEST_CASE(mobilenet_snippet) ...@@ -585,11 +585,11 @@ TEST_CASE(mobilenet_snippet)
auto q1 = add_quantize_op(mm, "quantizelinear", input, scale, zero); auto q1 = add_quantize_op(mm, "quantizelinear", input, scale, zero);
auto d5 = add_quantize_op(mm, "dequantizelinear", q1, scale, zero); auto d5 = add_quantize_op(mm, "dequantizelinear", q1, scale, zero);
auto c1 = mm.add_instruction(migraphx::make_op("convolution", auto c1 = mm.add_instruction(migraphx::make_op("convolution",
{{"padding", {0, 0, 0, 0}}, {{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}}, {"stride", {1, 1}},
{"dilation", {1, 1}}, {"dilation", {1, 1}},
{"group", 1}, {"group", 1},
{"padding_mode", 0}}), {"padding_mode", 0}}),
d5, d5,
d1); d1);
auto bc1 = mm.add_instruction( auto bc1 = mm.add_instruction(
...@@ -700,7 +700,7 @@ TEST_CASE(conv_correctness) ...@@ -700,7 +700,7 @@ TEST_CASE(conv_correctness)
auto result2 = p2.eval({{"input", input}, {"weights", weights}}).back(); auto result2 = p2.eval({{"input", input}, {"weights", weights}}).back();
std::vector<float> rv2(16); std::vector<float> rv2(16);
result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); }); result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(rv1, rv2)); EXPECT(migraphx::verify::verify_rms_range(rv1, rv2));
} }
TEST_CASE(dot_correctness) TEST_CASE(dot_correctness)
...@@ -750,7 +750,7 @@ TEST_CASE(dot_correctness) ...@@ -750,7 +750,7 @@ TEST_CASE(dot_correctness)
auto result2 = p2.eval({{"a", a}, {"b", b}}).back(); auto result2 = p2.eval({{"a", a}, {"b", b}}).back();
std::vector<float> rv2(sh3.elements()); std::vector<float> rv2(sh3.elements());
result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); }); result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(rv1, rv2)); EXPECT(migraphx::verify::verify_rms_range(rv1, rv2));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -67,6 +67,55 @@ migraphx::module make_concat_multibroadcast(const std::vector<size_t>& in_lens, ...@@ -67,6 +67,55 @@ migraphx::module make_concat_multibroadcast(const std::vector<size_t>& in_lens,
return m; return m;
} }
TEST_CASE(broadcast_transpose_scalar)
{
migraphx::module m1;
{
auto l = m1.add_parameter("x", {migraphx::shape::float_type, {1}, {0}});
auto mb =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3}}}), l);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), mb);
m1.add_return({t1});
}
run_pass(m1);
migraphx::module m2;
{
auto l = m2.add_parameter("x", {migraphx::shape::float_type, {1}, {0}});
auto mb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2}}}), l);
m2.add_return({mb});
}
EXPECT(m1 == m2);
}
TEST_CASE(broadcast_transpose_scalar_multi_use)
{
// multibroadcast used more than once
migraphx::module m1;
{
auto l = m1.add_parameter("x", {migraphx::shape::float_type, {1}, {0}});
auto mb =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3}}}), l);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), mb);
auto id = m1.add_instruction(migraphx::make_op("identity"), mb);
m1.add_return({t1, id});
}
run_pass(m1);
migraphx::module m2;
{
auto l = m2.add_parameter("x", {migraphx::shape::float_type, {1}, {0}});
auto mb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2}}}), l);
auto mb2 =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3}}}), l);
auto id = m2.add_instruction(migraphx::make_op("identity"), mb2);
m2.add_return({mb, id});
}
EXPECT(m1 == m2);
}
TEST_CASE(double_contig) TEST_CASE(double_contig)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment