Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
850ff8a1
Commit
850ff8a1
authored
Nov 16, 2022
by
charlie
Browse files
Fix onnx tests
parent
ca9019e2
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
7 additions
and
36 deletions
+7
-36
test/onnx/gemm_dyn_C_error.onnx
test/onnx/gemm_dyn_C_error.onnx
+0
-0
test/onnx/gemm_dyn_outer_test.onnx
test/onnx/gemm_dyn_outer_test.onnx
+0
-0
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+3
-5
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+4
-31
No files found.
test/onnx/gemm_dyn_C_error.onnx
View file @
850ff8a1
No preview for this file type
test/onnx/gemm_dyn_outer_test.onnx
View file @
850ff8a1
No preview for this file type
test/onnx/gen_onnx.py
View file @
850ff8a1
...
...
@@ -2074,18 +2074,16 @@ def gemm_dyn_inner_test():
def
gemm_dyn_outer_test
():
A
=
helper
.
make_tensor_value_info
(
'A'
,
TensorProto
.
FLOAT
,
[
5
,
None
])
B
=
helper
.
make_tensor_value_info
(
'B'
,
TensorProto
.
FLOAT
,
[
11
,
5
])
C
=
helper
.
make_tensor_value_info
(
'C'
,
TensorProto
.
FLOAT
,
[])
Y
=
helper
.
make_tensor_value_info
(
'Y'
,
TensorProto
.
FLOAT
,
[
None
,
11
])
node
=
onnx
.
helper
.
make_node
(
'Gemm'
,
inputs
=
[
'A'
,
'B'
,
'C'
],
inputs
=
[
'A'
,
'B'
],
outputs
=
[
'Y'
],
alpha
=
2.0
,
beta
=
2.0
,
transA
=
1
,
transB
=
1
)
return
([
node
],
[
A
,
B
,
C
],
[
Y
])
return
([
node
],
[
A
,
B
],
[
Y
])
@
onnx_test
...
...
@@ -2102,7 +2100,7 @@ def gemm_dyn_C_error():
beta
=
1.0
,
transA
=
1
)
return
([
node
],
[
A
,
B
],
[
Y
])
return
([
node
],
[
A
,
B
,
C
],
[
Y
])
@
onnx_test
...
...
test/onnx/onnx_test.cpp
View file @
850ff8a1
...
...
@@ -2101,9 +2101,9 @@ TEST_CASE(gemm_half_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("
1
", migraphx::shape{migraphx::shape::half_type, {8, 6}});
auto l1 = mm->add_parameter("
2
", migraphx::shape{migraphx::shape::half_type, {8, 7}});
auto l2 = mm->add_parameter("
3
", migraphx::shape{migraphx::shape::half_type, {6, 1}});
auto l0 = mm->add_parameter("
A
", migraphx::shape{migraphx::shape::half_type, {8, 6}});
auto l1 = mm->add_parameter("
B
", migraphx::shape{migraphx::shape::half_type, {8, 7}});
auto l2 = mm->add_parameter("
C
", migraphx::shape{migraphx::shape::half_type, {6, 1}});
auto alpha = 0.5f;
auto beta = 0.8f;
auto a_l = mm->add_literal(alpha);
...
...
@@ -2155,22 +2155,13 @@ TEST_CASE(gemm_dyn_outer_test)
auto l0 = mm->add_parameter(
"A", migraphx::shape{migraphx::shape::float_type, {{5, 5, 0}, {5, 10, 7}}});
auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type});
auto alpha = 2.f;
auto beta = 2.0f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, t1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {7, 11}}}), l2);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", l2_b->get_shape().lens()}}), b_l);
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
auto ret = mm->add_instruction(migraphx::make_op("add"), dot, l2_bb);
mm->add_return({ret});
mm->add_return({dot});
migraphx::onnx_options options;
options.default_dyn_dim_value = {5, 10, 7};
...
...
@@ -6031,24 +6022,6 @@ TEST_CASE(transpose_test)
EXPECT(p == prog);
}
TEST_CASE(transpose_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter(
"0",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 0}, {2, 2, 0}, {3, 3, 0}}});
std::vector<int64_t> perm{0, 3, 1, 2};
auto t0 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input);
mm->add_return({t0});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = migraphx::parse_onnx("transpose_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(topk_attrk_test)
{
migraphx::program p;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment