Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c6e4c1c6
Commit
c6e4c1c6
authored
Jan 17, 2023
by
charlie
Browse files
Merge branch 'dyn_onnx_gemm_C' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_test_runner
parents
b7bfdab0
03f0e278
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
45 additions
and
28 deletions
+45
-28
src/onnx/parse_gemm.cpp
src/onnx/parse_gemm.cpp
+28
-24
test/onnx/gemm_dyn_bias_test.onnx
test/onnx/gemm_dyn_bias_test.onnx
+0
-0
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+1
-1
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+16
-3
No files found.
src/onnx/parse_gemm.cpp
View file @
c6e4c1c6
...
@@ -90,41 +90,45 @@ struct parse_gemm : op_parser<parse_gemm>
...
@@ -90,41 +90,45 @@ struct parse_gemm : op_parser<parse_gemm>
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
1
])
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
1
])
:
args
[
1
];
:
args
[
1
];
auto
ret
=
info
.
add_instruction
(
make_op
(
"dot"
),
a_arg
,
b_arg
);
auto
dot_ins
=
info
.
add_instruction
(
make_op
(
"dot"
),
a_arg
,
b_arg
);
if
(
args
.
size
()
==
3
)
if
(
args
.
size
()
==
3
)
{
{
// TODO: support dynamic C input
if
(
not
float_equal
(
beta
,
0.0
f
))
if
(
std
::
any_of
(
args
.
cbegin
(),
args
.
cend
(),
[](
auto
in_arg
)
{
return
in_arg
->
get_shape
().
dynamic
();
}))
{
{
MIGRAPHX_THROW
(
"PARSE_GEMM: C input not handled for dynamic input shapes"
);
auto
c_arg
=
args
[
2
];
if
(
dot_ins
->
get_shape
().
dynamic
())
{
c_arg
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
),
args
[
2
],
dot_ins
);
}
}
if
(
not
float_equal
(
beta
,
0.0
f
)
and
args
[
2
]
->
get_shape
().
elements
()
>
0
)
else
{
{
auto
out_lens
=
a_arg
->
get_shape
().
lens
();
auto
out_lens
=
a_arg
->
get_shape
().
lens
();
out_lens
.
back
()
=
b_arg
->
get_shape
().
lens
().
back
();
out_lens
.
back
()
=
b_arg
->
get_shape
().
lens
().
back
();
auto
c_arg
=
args
[
2
];
auto
c_lens
=
c_arg
->
get_shape
().
lens
();
auto
c_lens
=
c_arg
->
get_shape
().
lens
();
if
(
not
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
c_lens
.
begin
(),
c_lens
.
end
()))
if
(
not
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
c_lens
.
begin
(),
c_lens
.
end
()))
{
{
c_arg
=
info
.
add_instruction
(
c_arg
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
args
[
2
]);
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
args
[
2
]);
}
}
}
if
(
not
float_equal
(
beta
,
1.0
f
))
{
auto
beta_literal
=
info
.
add_literal
(
beta
);
auto
beta_literal
=
info
.
add_literal
(
beta
);
auto
beta_c
=
info
.
add_broadcastable_binary_op
(
"mul"
,
c_arg
,
beta_literal
);
c_arg
=
info
.
add_broadcastable_binary_op
(
"mul"
,
c_arg
,
beta_literal
);
if
(
beta_c
->
get_shape
().
type
()
!=
dot_type
)
if
(
c_arg
->
get_shape
().
type
()
!=
dot_type
)
{
{
beta_c
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
c_arg
=
info
.
add_instruction
(
beta_c
);
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
c_arg
);
}
}
}
return
info
.
add_instruction
(
make_op
(
"add"
),
ret
,
beta_c
);
return
info
.
add_instruction
(
make_op
(
"add"
),
dot_ins
,
c_arg
);
}
}
}
}
return
dot_ins
;
return
ret
;
}
}
};
};
...
...
test/onnx/gemm_dyn_
C_error
.onnx
→
test/onnx/gemm_dyn_
bias_test
.onnx
View file @
c6e4c1c6
No preview for this file type
test/onnx/gen_onnx.py
View file @
c6e4c1c6
...
@@ -2215,7 +2215,7 @@ def gemm_dyn_outer_test():
...
@@ -2215,7 +2215,7 @@ def gemm_dyn_outer_test():
@
onnx_test
()
@
onnx_test
()
def
gemm_dyn_
C_error
():
def
gemm_dyn_
bias_test
():
A
=
helper
.
make_tensor_value_info
(
'A'
,
TensorProto
.
FLOAT
,
[
8
,
None
])
A
=
helper
.
make_tensor_value_info
(
'A'
,
TensorProto
.
FLOAT
,
[
8
,
None
])
B
=
helper
.
make_tensor_value_info
(
'B'
,
TensorProto
.
FLOAT
,
[
8
,
7
])
B
=
helper
.
make_tensor_value_info
(
'B'
,
TensorProto
.
FLOAT
,
[
8
,
7
])
C
=
helper
.
make_tensor_value_info
(
'C'
,
TensorProto
.
FLOAT
,
[
1
,
7
])
C
=
helper
.
make_tensor_value_info
(
'C'
,
TensorProto
.
FLOAT
,
[
1
,
7
])
...
...
test/onnx/onnx_test.cpp
View file @
c6e4c1c6
...
@@ -2278,11 +2278,24 @@ TEST_CASE(gemm_dyn_outer_test)
...
@@ -2278,11 +2278,24 @@ TEST_CASE(gemm_dyn_outer_test)
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
gemm_dyn_
C_error
)
TEST_CASE
(
gemm_dyn_
bias_test
)
{
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
x0
=
mm
->
add_parameter
(
"A"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
8
,
8
},
{
1
,
10
}}});
auto
x1
=
mm
->
add_parameter
(
"B"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
8
,
7
}});
auto
x2
=
mm
->
add_parameter
(
"C"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
7
}});
auto
x0_t
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
x0
);
auto
dot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x0_t
,
x1
);
auto
x2_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
),
x2
,
dot
);
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
dot
,
x2_b
);
mm
->
add_return
({
ret
});
migraphx
::
onnx_options
options
;
migraphx
::
onnx_options
options
;
options
.
default_dyn_dim_value
=
{
1
,
4
,
0
};
options
.
default_dyn_dim_value
=
{
1
,
10
};
EXPECT
(
test
::
throws
([
&
]
{
migraphx
::
parse_onnx
(
"gemm_dyn_C_error.onnx"
,
options
);
}));
auto
prog
=
parse_onnx
(
"gemm_dyn_bias_test.onnx"
,
options
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
gemm_rank_error
)
TEST_CASE
(
gemm_rank_error
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment