Unverified Commit 913ae362 authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into optimize

parents f1e16656 b8c8d09b
This diff is collapsed.
...@@ -35,7 +35,7 @@ ...@@ -35,7 +35,7 @@
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
template <class T> template <class T>
void matmul_test() void dot_2d_test()
{ {
migraphx::program p; migraphx::program p;
...@@ -82,11 +82,11 @@ void matmul_test() ...@@ -82,11 +82,11 @@ void matmul_test()
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify_range(c, results_vector));
} }
TEST_CASE_REGISTER(matmul_test<float>) TEST_CASE_REGISTER(dot_2d_test<float>)
TEST_CASE_REGISTER(matmul_test<double>) TEST_CASE_REGISTER(dot_2d_test<double>)
template <class T> template <class T>
void matmul_test_ex() void dot_4d_test()
{ {
migraphx::program p; migraphx::program p;
...@@ -133,10 +133,10 @@ void matmul_test_ex() ...@@ -133,10 +133,10 @@ void matmul_test_ex()
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify_range(c, results_vector));
} }
TEST_CASE_REGISTER(matmul_test_ex<float>) TEST_CASE_REGISTER(dot_4d_test<float>)
TEST_CASE_REGISTER(matmul_test_ex<double>) TEST_CASE_REGISTER(dot_4d_test<double>)
TEST_CASE(matmul_mutli_dim_2) TEST_CASE(dot_3D_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -189,7 +189,7 @@ TEST_CASE(matmul_mutli_dim_2) ...@@ -189,7 +189,7 @@ TEST_CASE(matmul_mutli_dim_2)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(gemm_mutli_dim_2_beta0) TEST_CASE(dot_3D_C_test0)
{ {
migraphx::program p; migraphx::program p;
...@@ -265,7 +265,7 @@ TEST_CASE(gemm_mutli_dim_2_beta0) ...@@ -265,7 +265,7 @@ TEST_CASE(gemm_mutli_dim_2_beta0)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(gemm_beta_0) TEST_CASE(dot_3D_C_test1)
{ {
migraphx::program p; migraphx::program p;
...@@ -324,7 +324,7 @@ TEST_CASE(gemm_beta_0) ...@@ -324,7 +324,7 @@ TEST_CASE(gemm_beta_0)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(matmul_mutli_dim_2_3) TEST_CASE(dot_4D_test1)
{ {
migraphx::program p; migraphx::program p;
...@@ -363,7 +363,7 @@ TEST_CASE(matmul_mutli_dim_2_3) ...@@ -363,7 +363,7 @@ TEST_CASE(matmul_mutli_dim_2_3)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(gemm_mutli_dim1_2_3) TEST_CASE(dot_4D_alpha_beta_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -417,7 +417,7 @@ TEST_CASE(gemm_mutli_dim1_2_3) ...@@ -417,7 +417,7 @@ TEST_CASE(gemm_mutli_dim1_2_3)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(gemm_mutli_3args) TEST_CASE(dot_4D_alpha_beta_C_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -469,7 +469,7 @@ TEST_CASE(gemm_mutli_3args) ...@@ -469,7 +469,7 @@ TEST_CASE(gemm_mutli_3args)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(gemm_3args) TEST_CASE(dot_2D_C_test0)
{ {
{ {
migraphx::program p; migraphx::program p;
...@@ -533,7 +533,7 @@ TEST_CASE(gemm_3args) ...@@ -533,7 +533,7 @@ TEST_CASE(gemm_3args)
} }
} }
TEST_CASE(matmul_vv_inner_product) TEST_CASE(dot_vv_inner_product)
{ {
{ {
migraphx::program p; migraphx::program p;
...@@ -608,7 +608,7 @@ TEST_CASE(matmul_vv_inner_product) ...@@ -608,7 +608,7 @@ TEST_CASE(matmul_vv_inner_product)
} }
} }
TEST_CASE(matmul_vm) TEST_CASE(dot_vm)
{ {
{ {
migraphx::program p; migraphx::program p;
...@@ -778,7 +778,7 @@ TEST_CASE(matmul_vm) ...@@ -778,7 +778,7 @@ TEST_CASE(matmul_vm)
} }
} }
TEST_CASE(matmul_mv) TEST_CASE(dot_mv)
{ {
{ {
migraphx::program p; migraphx::program p;
...@@ -899,7 +899,7 @@ TEST_CASE(matmul_mv) ...@@ -899,7 +899,7 @@ TEST_CASE(matmul_mv)
} }
} }
TEST_CASE(matmul_mm1) TEST_CASE(dot_mm1)
{ {
{ {
migraphx::program p; migraphx::program p;
...@@ -1006,7 +1006,7 @@ TEST_CASE(matmul_mm1) ...@@ -1006,7 +1006,7 @@ TEST_CASE(matmul_mm1)
} }
} }
TEST_CASE(matmul_mm2) TEST_CASE(dot_mm2)
{ {
{ {
migraphx::program p; migraphx::program p;
...@@ -1193,6 +1193,113 @@ TEST_CASE(matmul_mm2) ...@@ -1193,6 +1193,113 @@ TEST_CASE(matmul_mm2)
} }
} }
TEST_CASE(dot_dyn_2D_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}}};
auto ap = mm->add_parameter("a", a_shape);
migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}};
auto bp = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::make_op("dot"), ap, bp);
p.compile(migraphx::ref::target{});
std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {4, 5}};
migraphx::parameter_map params;
params["a"] = migraphx::argument(input_fixed_shape, a.data());
params["b"] = migraphx::argument(b_shape, b.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE(dot_dyn_4D_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type,
{{1, 1, 0}, {1, 1, 0}, {4, 6, 4}, {5, 5, 0}}};
auto al = mm->add_parameter("a", a_shape);
migraphx::shape b_shape{migraphx::shape::float_type, {1, 1, 5, 3}};
auto bl = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::make_op("dot"), al, bl);
p.compile(migraphx::ref::target{});
std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {1, 1, 4, 5}};
migraphx::shape input_fixed_shape1{migraphx::shape::float_type, {1, 1, 5, 3}};
migraphx::parameter_map params;
params["a"] = migraphx::argument(input_fixed_shape0, a.data());
params["b"] = migraphx::argument(input_fixed_shape1, b.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE(quant_dot_2args_multi4) TEST_CASE(quant_dot_2args_multi4)
{ {
{ {
......
This diff is collapsed.
This diff is collapsed.
...@@ -160,6 +160,20 @@ TEST_CASE(test_shape_dynamic_compares) ...@@ -160,6 +160,20 @@ TEST_CASE(test_shape_dynamic_compares)
EXPECT(ss0.str() != ss3.str()); EXPECT(ss0.str() != ss3.str());
} }
TEST_CASE(dynamic_dimension_size_t_compares)
{
using migraphx::shape;
auto a = shape::dynamic_dimension{2, 2, 2};
EXPECT(a == 2);
EXPECT(a != 3);
EXPECT(static_cast<std::size_t>(2) == a);
EXPECT(static_cast<std::size_t>(3) != a);
auto b = shape::dynamic_dimension{2, 4, 0};
EXPECT(b != 2);
EXPECT(static_cast<std::size_t>(2) != b);
}
TEST_CASE(test_shape_dynamic_errors) TEST_CASE(test_shape_dynamic_errors)
{ {
using migraphx::shape; using migraphx::shape;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment