Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c6e4c1c6
Commit
c6e4c1c6
authored
Jan 17, 2023
by
charlie
Browse files
Merge branch 'dyn_onnx_gemm_C' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_test_runner
parents
b7bfdab0
03f0e278
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
45 additions
and
28 deletions
+45
-28
src/onnx/parse_gemm.cpp
src/onnx/parse_gemm.cpp
+28
-24
test/onnx/gemm_dyn_bias_test.onnx
test/onnx/gemm_dyn_bias_test.onnx
+0
-0
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+1
-1
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+16
-3
No files found.
src/onnx/parse_gemm.cpp
View file @
c6e4c1c6
...
...
@@ -90,41 +90,45 @@ struct parse_gemm : op_parser<parse_gemm>
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
1
])
:
args
[
1
];
auto
ret
=
info
.
add_instruction
(
make_op
(
"dot"
),
a_arg
,
b_arg
);
auto
dot_ins
=
info
.
add_instruction
(
make_op
(
"dot"
),
a_arg
,
b_arg
);
if
(
args
.
size
()
==
3
)
{
// TODO: support dynamic C input
if
(
std
::
any_of
(
args
.
cbegin
(),
args
.
cend
(),
[](
auto
in_arg
)
{
return
in_arg
->
get_shape
().
dynamic
();
}))
if
(
not
float_equal
(
beta
,
0.0
f
))
{
MIGRAPHX_THROW
(
"PARSE_GEMM: C input not handled for dynamic input shapes"
);
}
if
(
not
float_equal
(
beta
,
0.0
f
)
and
args
[
2
]
->
get_shape
().
elements
()
>
0
)
{
auto
out_lens
=
a_arg
->
get_shape
().
lens
();
out_lens
.
back
()
=
b_arg
->
get_shape
().
lens
().
back
();
auto
c_arg
=
args
[
2
];
auto
c_lens
=
c_arg
->
get_shape
().
lens
();
if
(
not
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
c_lens
.
begin
(),
c_lens
.
end
()))
auto
c_arg
=
args
[
2
];
if
(
dot_ins
->
get_shape
().
dynamic
())
{
c_arg
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
),
args
[
2
],
dot_ins
);
}
else
{
c_arg
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
args
[
2
]);
auto
out_lens
=
a_arg
->
get_shape
().
lens
();
out_lens
.
back
()
=
b_arg
->
get_shape
().
lens
().
back
();
auto
c_lens
=
c_arg
->
get_shape
().
lens
();
if
(
not
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
c_lens
.
begin
(),
c_lens
.
end
()))
{
c_arg
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
args
[
2
]);
}
}
auto
beta_literal
=
info
.
add_literal
(
beta
);
auto
beta_c
=
info
.
add_broadcastable_binary_op
(
"mul"
,
c_arg
,
beta_literal
);
if
(
beta_c
->
get_shape
().
type
()
!=
dot_type
)
if
(
not
float_equal
(
beta
,
1.0
f
))
{
beta_c
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
beta_c
);
auto
beta_literal
=
info
.
add_literal
(
beta
);
c_arg
=
info
.
add_broadcastable_binary_op
(
"mul"
,
c_arg
,
beta_literal
);
if
(
c_arg
->
get_shape
().
type
()
!=
dot_type
)
{
c_arg
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
c_arg
);
}
}
return
info
.
add_instruction
(
make_op
(
"add"
),
ret
,
beta_c
);
return
info
.
add_instruction
(
make_op
(
"add"
),
dot_ins
,
c_arg
);
}
}
return
ret
;
return
dot_ins
;
}
};
...
...
test/onnx/gemm_dyn_
C_error
.onnx
→
test/onnx/gemm_dyn_
bias_test
.onnx
View file @
c6e4c1c6
No preview for this file type
test/onnx/gen_onnx.py
View file @
c6e4c1c6
...
...
@@ -2215,7 +2215,7 @@ def gemm_dyn_outer_test():
@
onnx_test
()
def
gemm_dyn_
C_error
():
def
gemm_dyn_
bias_test
():
A
=
helper
.
make_tensor_value_info
(
'A'
,
TensorProto
.
FLOAT
,
[
8
,
None
])
B
=
helper
.
make_tensor_value_info
(
'B'
,
TensorProto
.
FLOAT
,
[
8
,
7
])
C
=
helper
.
make_tensor_value_info
(
'C'
,
TensorProto
.
FLOAT
,
[
1
,
7
])
...
...
test/onnx/onnx_test.cpp
View file @
c6e4c1c6
...
...
@@ -2278,11 +2278,24 @@ TEST_CASE(gemm_dyn_outer_test)
EXPECT(p == prog);
}
TEST_CASE(gemm_dyn_
C_error
)
TEST_CASE(gemm_dyn_
bias_test
)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x0 =
mm->add_parameter("A", migraphx::shape{migraphx::shape::float_type, {{8, 8}, {1, 10}}});
auto x1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {8, 7}});
auto x2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type, {1, 7}});
auto x0_t = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), x0);
auto dot = mm->add_instruction(migraphx::make_op("dot"), x0_t, x1);
auto x2_b = mm->add_instruction(migraphx::make_op("multibroadcast"), x2, dot);
auto ret = mm->add_instruction(migraphx::make_op("add"), dot, x2_b);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
EXPECT(test::throws([&] { migraphx::parse_onnx("gemm_dyn_C_error.onnx", options); }));
options.default_dyn_dim_value = {1, 10};
auto prog = parse_onnx("gemm_dyn_bias_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(gemm_rank_error)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment