"vscode:/vscode.git/clone" did not exist on "76aa5e8186f23e3097608894a4bc678d1300e180"
Unverified Commit 913ae362 authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into optimize

parents f1e16656 b8c8d09b
This diff is collapsed.
......@@ -35,7 +35,7 @@
#include <migraphx/half.hpp>
template <class T>
void matmul_test()
void dot_2d_test()
{
migraphx::program p;
......@@ -82,11 +82,11 @@ void matmul_test()
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(matmul_test<float>)
TEST_CASE_REGISTER(matmul_test<double>)
TEST_CASE_REGISTER(dot_2d_test<float>)
TEST_CASE_REGISTER(dot_2d_test<double>)
template <class T>
void matmul_test_ex()
void dot_4d_test()
{
migraphx::program p;
......@@ -133,10 +133,10 @@ void matmul_test_ex()
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(matmul_test_ex<float>)
TEST_CASE_REGISTER(matmul_test_ex<double>)
TEST_CASE_REGISTER(dot_4d_test<float>)
TEST_CASE_REGISTER(dot_4d_test<double>)
TEST_CASE(matmul_mutli_dim_2)
TEST_CASE(dot_3D_test)
{
migraphx::program p;
......@@ -189,7 +189,7 @@ TEST_CASE(matmul_mutli_dim_2)
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_mutli_dim_2_beta0)
TEST_CASE(dot_3D_C_test0)
{
migraphx::program p;
......@@ -265,7 +265,7 @@ TEST_CASE(gemm_mutli_dim_2_beta0)
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_beta_0)
TEST_CASE(dot_3D_C_test1)
{
migraphx::program p;
......@@ -324,7 +324,7 @@ TEST_CASE(gemm_beta_0)
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(matmul_mutli_dim_2_3)
TEST_CASE(dot_4D_test1)
{
migraphx::program p;
......@@ -363,7 +363,7 @@ TEST_CASE(matmul_mutli_dim_2_3)
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_mutli_dim1_2_3)
TEST_CASE(dot_4D_alpha_beta_test)
{
migraphx::program p;
......@@ -417,7 +417,7 @@ TEST_CASE(gemm_mutli_dim1_2_3)
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_mutli_3args)
TEST_CASE(dot_4D_alpha_beta_C_test)
{
migraphx::program p;
......@@ -469,7 +469,7 @@ TEST_CASE(gemm_mutli_3args)
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_3args)
TEST_CASE(dot_2D_C_test0)
{
{
migraphx::program p;
......@@ -533,7 +533,7 @@ TEST_CASE(gemm_3args)
}
}
TEST_CASE(matmul_vv_inner_product)
TEST_CASE(dot_vv_inner_product)
{
{
migraphx::program p;
......@@ -608,7 +608,7 @@ TEST_CASE(matmul_vv_inner_product)
}
}
TEST_CASE(matmul_vm)
TEST_CASE(dot_vm)
{
{
migraphx::program p;
......@@ -778,7 +778,7 @@ TEST_CASE(matmul_vm)
}
}
TEST_CASE(matmul_mv)
TEST_CASE(dot_mv)
{
{
migraphx::program p;
......@@ -899,7 +899,7 @@ TEST_CASE(matmul_mv)
}
}
TEST_CASE(matmul_mm1)
TEST_CASE(dot_mm1)
{
{
migraphx::program p;
......@@ -1006,7 +1006,7 @@ TEST_CASE(matmul_mm1)
}
}
TEST_CASE(matmul_mm2)
TEST_CASE(dot_mm2)
{
{
migraphx::program p;
......@@ -1193,6 +1193,113 @@ TEST_CASE(matmul_mm2)
}
}
TEST_CASE(dot_dyn_2D_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}}};
auto ap = mm->add_parameter("a", a_shape);
migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}};
auto bp = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::make_op("dot"), ap, bp);
p.compile(migraphx::ref::target{});
std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {4, 5}};
migraphx::parameter_map params;
params["a"] = migraphx::argument(input_fixed_shape, a.data());
params["b"] = migraphx::argument(b_shape, b.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE(dot_dyn_4D_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type,
{{1, 1, 0}, {1, 1, 0}, {4, 6, 4}, {5, 5, 0}}};
auto al = mm->add_parameter("a", a_shape);
migraphx::shape b_shape{migraphx::shape::float_type, {1, 1, 5, 3}};
auto bl = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::make_op("dot"), al, bl);
p.compile(migraphx::ref::target{});
std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {1, 1, 4, 5}};
migraphx::shape input_fixed_shape1{migraphx::shape::float_type, {1, 1, 5, 3}};
migraphx::parameter_map params;
params["a"] = migraphx::argument(input_fixed_shape0, a.data());
params["b"] = migraphx::argument(input_fixed_shape1, b.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE(quant_dot_2args_multi4)
{
{
......
This diff is collapsed.
This diff is collapsed.
......@@ -160,6 +160,20 @@ TEST_CASE(test_shape_dynamic_compares)
EXPECT(ss0.str() != ss3.str());
}
TEST_CASE(dynamic_dimension_size_t_compares)
{
using migraphx::shape;
auto a = shape::dynamic_dimension{2, 2, 2};
EXPECT(a == 2);
EXPECT(a != 3);
EXPECT(static_cast<std::size_t>(2) == a);
EXPECT(static_cast<std::size_t>(3) != a);
auto b = shape::dynamic_dimension{2, 4, 0};
EXPECT(b != 2);
EXPECT(static_cast<std::size_t>(2) != b);
}
TEST_CASE(test_shape_dynamic_errors)
{
using migraphx::shape;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment