Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
03f0e278
Commit
03f0e278
authored
Jan 17, 2023
by
charlie
Browse files
Fix parsing and add test
parent
22012c6d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
31 additions
and
7 deletions
+31
-7
src/onnx/parse_gemm.cpp
src/onnx/parse_gemm.cpp
+10
-6
test/onnx/gemm_dyn_bias_test.onnx
test/onnx/gemm_dyn_bias_test.onnx
+0
-0
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+1
-1
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+20
-0
No files found.
src/onnx/parse_gemm.cpp
View file @
03f0e278
...
...
@@ -113,15 +113,19 @@ struct parse_gemm : op_parser<parse_gemm>
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
args
[
2
]);
}
}
auto
beta_literal
=
info
.
add_literal
(
beta
);
auto
beta_c
=
info
.
add_broadcastable_binary_op
(
"mul"
,
c_arg
,
beta_literal
);
if
(
beta_c
->
get_shape
().
type
()
!=
dot_type
)
if
(
not
float_equal
(
beta
,
1.0
f
))
{
beta_c
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
beta_c
);
auto
beta_literal
=
info
.
add_literal
(
beta
);
c_arg
=
info
.
add_broadcastable_binary_op
(
"mul"
,
c_arg
,
beta_literal
);
if
(
c_arg
->
get_shape
().
type
()
!=
dot_type
)
{
c_arg
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
c_arg
);
}
}
return
info
.
add_instruction
(
make_op
(
"add"
),
dot_ins
,
beta_c
);
return
info
.
add_instruction
(
make_op
(
"add"
),
dot_ins
,
c_arg
);
}
}
return
dot_ins
;
...
...
test/onnx/gemm_dyn_bias_test.onnx
0 → 100644
View file @
03f0e278
File added
test/onnx/gen_onnx.py
View file @
03f0e278
...
...
@@ -2215,7 +2215,7 @@ def gemm_dyn_outer_test():
@
onnx_test
()
def
gemm_dyn_
C_error
():
def
gemm_dyn_
bias_test
():
A
=
helper
.
make_tensor_value_info
(
'A'
,
TensorProto
.
FLOAT
,
[
8
,
None
])
B
=
helper
.
make_tensor_value_info
(
'B'
,
TensorProto
.
FLOAT
,
[
8
,
7
])
C
=
helper
.
make_tensor_value_info
(
'C'
,
TensorProto
.
FLOAT
,
[
1
,
7
])
...
...
test/onnx/onnx_test.cpp
View file @
03f0e278
...
...
@@ -2278,6 +2278,26 @@ TEST_CASE(gemm_dyn_outer_test)
EXPECT(p == prog);
}
TEST_CASE(gemm_dyn_bias_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x0 =
mm->add_parameter("A", migraphx::shape{migraphx::shape::float_type, {{8, 8}, {1, 10}}});
auto x1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {8, 7}});
auto x2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type, {1, 7}});
auto x0_t = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), x0);
auto dot = mm->add_instruction(migraphx::make_op("dot"), x0_t, x1);
auto x2_b = mm->add_instruction(migraphx::make_op("multibroadcast"), x2, dot);
auto ret = mm->add_instruction(migraphx::make_op("add"), dot, x2_b);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10};
auto prog = parse_onnx("gemm_dyn_bias_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(gemm_rank_error)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("gemm_rank_error.onnx"); }));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment