Commit 84290a10 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add gpu tests for dot operator.

parent 1f679721
...@@ -886,6 +886,118 @@ struct gemm_multi_dim_2 : verify_program<gemm_multi_dim_2> ...@@ -886,6 +886,118 @@ struct gemm_multi_dim_2 : verify_program<gemm_multi_dim_2>
} }
}; };
struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2);
return p;
}
};
struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2);
return p;
}
};
struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2);
return p;
}
};
struct gemm_2args_mm_4 : verify_program<gemm_2args_mm_4>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2);
return p;
}
};
struct gemm_2args_mm_5 : verify_program<gemm_2args_mm_5>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2);
return p;
}
};
struct gemm_2args_mm_6 : verify_program<gemm_2args_mm_6>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2);
return p;
}
};
struct gemm_2args_mm_7 : verify_program<gemm_2args_mm_7>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2);
return p;
}
};
struct gemm_multi_dim_2_3 : verify_program<gemm_multi_dim_2_3> struct gemm_multi_dim_2_3 : verify_program<gemm_multi_dim_2_3>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -902,6 +1014,86 @@ struct gemm_multi_dim_2_3 : verify_program<gemm_multi_dim_2_3> ...@@ -902,6 +1014,86 @@ struct gemm_multi_dim_2_3 : verify_program<gemm_multi_dim_2_3>
} }
}; };
struct gemm_2args_vv : verify_program<gemm_2args_vv>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2);
return p;
}
};
struct gemm_2args_mv : verify_program<gemm_2args_mv>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {3, 5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2);
return p;
}
};
struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 3, 5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2);
return p;
}
};
struct gemm_2args_vm : verify_program<gemm_2args_vm>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2);
return p;
}
};
struct gemm_2args_vbm : verify_program<gemm_2args_vbm>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 2, 5, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2);
return p;
}
};
struct gemm_multi_3args : verify_program<gemm_multi_3args> struct gemm_multi_3args : verify_program<gemm_multi_3args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -922,6 +1114,86 @@ struct gemm_multi_3args : verify_program<gemm_multi_3args> ...@@ -922,6 +1114,86 @@ struct gemm_multi_3args : verify_program<gemm_multi_3args>
} }
}; };
struct gemm_multi_3args_c0 : verify_program<gemm_multi_3args_c0>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 2}};
migraphx::shape m3_shape{migraphx::shape::float_type};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
}
};
struct gemm_multi_3args_c5 : verify_program<gemm_multi_3args_c5>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 5}};
migraphx::shape m3_shape{migraphx::shape::float_type, {5}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
}
};
struct gemm_multi_3args_c21 : verify_program<gemm_multi_3args_c21>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 5}};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 1}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
}
};
struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 5}};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 5}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
}
};
struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0> struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment