Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
03f0e278
Commit
03f0e278
authored
Jan 17, 2023
by
charlie
Browse files
Fix parsing and add test
parent
22012c6d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
31 additions
and
7 deletions
+31
-7
src/onnx/parse_gemm.cpp
src/onnx/parse_gemm.cpp
+10
-6
test/onnx/gemm_dyn_bias_test.onnx
test/onnx/gemm_dyn_bias_test.onnx
+0
-0
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+1
-1
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+20
-0
No files found.
src/onnx/parse_gemm.cpp
View file @
03f0e278
...
...
@@ -113,15 +113,19 @@ struct parse_gemm : op_parser<parse_gemm>
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
args
[
2
]);
}
}
if
(
not
float_equal
(
beta
,
1.0
f
))
{
auto
beta_literal
=
info
.
add_literal
(
beta
);
auto
beta_c
=
info
.
add_broadcastable_binary_op
(
"mul"
,
c_arg
,
beta_literal
);
if
(
beta_c
->
get_shape
().
type
()
!=
dot_type
)
c_arg
=
info
.
add_broadcastable_binary_op
(
"mul"
,
c_arg
,
beta_literal
);
if
(
c_arg
->
get_shape
().
type
()
!=
dot_type
)
{
beta_c
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
beta_c
);
c_arg
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
c_arg
);
}
}
return
info
.
add_instruction
(
make_op
(
"add"
),
dot_ins
,
beta_c
);
return
info
.
add_instruction
(
make_op
(
"add"
),
dot_ins
,
c_arg
);
}
}
return
dot_ins
;
...
...
test/onnx/gemm_dyn_bias_test.onnx
0 → 100644
View file @
03f0e278
File added
test/onnx/gen_onnx.py
View file @
03f0e278
...
...
@@ -2215,7 +2215,7 @@ def gemm_dyn_outer_test():
@
onnx_test
()
def
gemm_dyn_
C_error
():
def
gemm_dyn_
bias_test
():
A
=
helper
.
make_tensor_value_info
(
'A'
,
TensorProto
.
FLOAT
,
[
8
,
None
])
B
=
helper
.
make_tensor_value_info
(
'B'
,
TensorProto
.
FLOAT
,
[
8
,
7
])
C
=
helper
.
make_tensor_value_info
(
'C'
,
TensorProto
.
FLOAT
,
[
1
,
7
])
...
...
test/onnx/onnx_test.cpp
View file @
03f0e278
...
...
@@ -2278,6 +2278,26 @@ TEST_CASE(gemm_dyn_outer_test)
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
gemm_dyn_bias_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
x0
=
mm
->
add_parameter
(
"A"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
8
,
8
},
{
1
,
10
}}});
auto
x1
=
mm
->
add_parameter
(
"B"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
8
,
7
}});
auto
x2
=
mm
->
add_parameter
(
"C"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
7
}});
auto
x0_t
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
x0
);
auto
dot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x0_t
,
x1
);
auto
x2_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
),
x2
,
dot
);
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
dot
,
x2_b
);
mm
->
add_return
({
ret
});
migraphx
::
onnx_options
options
;
options
.
default_dyn_dim_value
=
{
1
,
10
};
auto
prog
=
parse_onnx
(
"gemm_dyn_bias_test.onnx"
,
options
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
gemm_rank_error
)
{
EXPECT
(
test
::
throws
([
&
]
{
migraphx
::
parse_onnx
(
"gemm_rank_error.onnx"
);
}));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment