Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f2c7e9b3
Unverified
Commit
f2c7e9b3
authored
Jan 20, 2023
by
Charlie Lin
Committed by
GitHub
Jan 20, 2023
Browse files
Dynamic onnx gemm bias (#1527)
Adds support for parsing dynamic ONNX gemm bias input C
parent
d309e02f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
45 additions
and
28 deletions
+45
-28
src/onnx/parse_gemm.cpp
src/onnx/parse_gemm.cpp
+28
-24
test/onnx/gemm_dyn_bias_test.onnx
test/onnx/gemm_dyn_bias_test.onnx
+0
-0
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+1
-1
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+16
-3
No files found.
src/onnx/parse_gemm.cpp
View file @
f2c7e9b3
...
@@ -90,41 +90,45 @@ struct parse_gemm : op_parser<parse_gemm>
...
@@ -90,41 +90,45 @@ struct parse_gemm : op_parser<parse_gemm>
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
1
])
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
1
])
:
args
[
1
];
:
args
[
1
];
auto
ret
=
info
.
add_instruction
(
make_op
(
"dot"
),
a_arg
,
b_arg
);
auto
dot_ins
=
info
.
add_instruction
(
make_op
(
"dot"
),
a_arg
,
b_arg
);
if
(
args
.
size
()
==
3
)
if
(
args
.
size
()
==
3
)
{
{
// TODO: support dynamic C input
if
(
not
float_equal
(
beta
,
0.0
f
))
if
(
std
::
any_of
(
args
.
cbegin
(),
args
.
cend
(),
[](
auto
in_arg
)
{
return
in_arg
->
get_shape
().
dynamic
();
}))
{
{
MIGRAPHX_THROW
(
"PARSE_GEMM: C input not handled for dynamic input shapes"
);
auto
c_arg
=
args
[
2
];
}
if
(
dot_ins
->
get_shape
().
dynamic
())
if
(
not
float_equal
(
beta
,
0.0
f
)
and
args
[
2
]
->
get_shape
().
elements
()
>
0
)
{
{
c_arg
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
),
args
[
2
],
dot_ins
);
auto
out_lens
=
a_arg
->
get_shape
().
lens
();
}
out_lens
.
back
()
=
b_arg
->
get_shape
().
lens
().
back
();
else
auto
c_arg
=
args
[
2
];
auto
c_lens
=
c_arg
->
get_shape
().
lens
();
if
(
not
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
c_lens
.
begin
(),
c_lens
.
end
()))
{
{
c_arg
=
info
.
add_instruction
(
auto
out_lens
=
a_arg
->
get_shape
().
lens
();
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
args
[
2
]);
out_lens
.
back
()
=
b_arg
->
get_shape
().
lens
().
back
();
auto
c_lens
=
c_arg
->
get_shape
().
lens
();
if
(
not
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
c_lens
.
begin
(),
c_lens
.
end
()))
{
c_arg
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
args
[
2
]);
}
}
}
auto
beta_literal
=
info
.
add_literal
(
beta
);
auto
beta_c
=
info
.
add_broadcastable_binary_op
(
"mul"
,
c_arg
,
beta_literal
);
if
(
not
float_equal
(
beta
,
1.0
f
))
if
(
beta_c
->
get_shape
().
type
()
!=
dot_type
)
{
{
beta_c
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
auto
beta_literal
=
info
.
add_literal
(
beta
);
beta_c
);
c_arg
=
info
.
add_broadcastable_binary_op
(
"mul"
,
c_arg
,
beta_literal
);
if
(
c_arg
->
get_shape
().
type
()
!=
dot_type
)
{
c_arg
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
c_arg
);
}
}
}
return
info
.
add_instruction
(
make_op
(
"add"
),
ret
,
beta_c
);
return
info
.
add_instruction
(
make_op
(
"add"
),
dot_ins
,
c_arg
);
}
}
}
}
return
dot_ins
;
return
ret
;
}
}
};
};
...
...
test/onnx/gemm_dyn_bias_test.onnx
0 → 100644
View file @
f2c7e9b3
File added
test/onnx/gen_onnx.py
View file @
f2c7e9b3
...
@@ -2215,7 +2215,7 @@ def gemm_dyn_outer_test():
...
@@ -2215,7 +2215,7 @@ def gemm_dyn_outer_test():
@
onnx_test
()
@
onnx_test
()
def
gemm_dyn_
C_error
():
def
gemm_dyn_
bias_test
():
A
=
helper
.
make_tensor_value_info
(
'A'
,
TensorProto
.
FLOAT
,
[
8
,
None
])
A
=
helper
.
make_tensor_value_info
(
'A'
,
TensorProto
.
FLOAT
,
[
8
,
None
])
B
=
helper
.
make_tensor_value_info
(
'B'
,
TensorProto
.
FLOAT
,
[
8
,
7
])
B
=
helper
.
make_tensor_value_info
(
'B'
,
TensorProto
.
FLOAT
,
[
8
,
7
])
C
=
helper
.
make_tensor_value_info
(
'C'
,
TensorProto
.
FLOAT
,
[
1
,
7
])
C
=
helper
.
make_tensor_value_info
(
'C'
,
TensorProto
.
FLOAT
,
[
1
,
7
])
...
...
test/onnx/onnx_test.cpp
View file @
f2c7e9b3
...
@@ -2278,11 +2278,24 @@ TEST_CASE(gemm_dyn_outer_test)
...
@@ -2278,11 +2278,24 @@ TEST_CASE(gemm_dyn_outer_test)
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
gemm_dyn_
C_error
)
TEST_CASE
(
gemm_dyn_
bias_test
)
{
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
x0
=
mm
->
add_parameter
(
"A"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
8
,
8
},
{
1
,
10
}}});
auto
x1
=
mm
->
add_parameter
(
"B"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
8
,
7
}});
auto
x2
=
mm
->
add_parameter
(
"C"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
7
}});
auto
x0_t
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
x0
);
auto
dot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x0_t
,
x1
);
auto
x2_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
),
x2
,
dot
);
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
dot
,
x2_b
);
mm
->
add_return
({
ret
});
migraphx
::
onnx_options
options
;
migraphx
::
onnx_options
options
;
options
.
default_dyn_dim_value
=
{
1
,
4
,
0
};
options
.
default_dyn_dim_value
=
{
1
,
10
};
EXPECT
(
test
::
throws
([
&
]
{
migraphx
::
parse_onnx
(
"gemm_dyn_C_error.onnx"
,
options
);
}));
auto
prog
=
parse_onnx
(
"gemm_dyn_bias_test.onnx"
,
options
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
gemm_rank_error
)
TEST_CASE
(
gemm_rank_error
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment