Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e7c17bd6
Commit
e7c17bd6
authored
Dec 01, 2022
by
charlie
Browse files
Tidy style fix
parent
801a349c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
31 deletions
+35
-31
src/onnx/parse_gemm.cpp
src/onnx/parse_gemm.cpp
+35
-31
No files found.
src/onnx/parse_gemm.cpp
View file @
e7c17bd6
...
...
@@ -39,19 +39,19 @@ struct parse_gemm : op_parser<parse_gemm>
onnx_parser
::
node_info
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
auto
A
=
args
[
0
];
auto
B
=
args
[
1
];
if
(
A
->
get_shape
().
ndim
()
!=
2
or
B
->
get_shape
().
ndim
()
!=
2
)
auto
a_arg
=
args
[
0
];
auto
b_arg
=
args
[
1
];
if
(
a_arg
->
get_shape
().
ndim
()
!=
2
or
b_arg
->
get_shape
().
ndim
()
!=
2
)
{
MIGRAPHX_THROW
(
"PARSE_GEMM: A and B should be rank 2, A is rank "
+
std
::
to_string
(
A
->
get_shape
().
ndim
())
+
"B is rank "
+
std
::
to_string
(
B
->
get_shape
().
ndim
()));
std
::
to_string
(
a_arg
->
get_shape
().
ndim
())
+
"B is rank "
+
std
::
to_string
(
b_arg
->
get_shape
().
ndim
()));
}
float
alpha
=
1.0
f
;
float
beta
=
1.0
f
;
bool
transa
=
false
;
bool
transb
=
false
;
float
alpha
=
1.0
f
;
float
beta
=
1.0
f
;
bool
trans
_
a
=
false
;
bool
trans
_
b
=
false
;
if
(
contains
(
info
.
attributes
,
"alpha"
))
{
alpha
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"alpha"
)).
at
<
float
>
();
...
...
@@ -62,11 +62,11 @@ struct parse_gemm : op_parser<parse_gemm>
}
if
(
contains
(
info
.
attributes
,
"transA"
))
{
transa
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"transA"
)).
at
<
bool
>
();
trans
_
a
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"transA"
)).
at
<
bool
>
();
}
if
(
contains
(
info
.
attributes
,
"transB"
))
{
transb
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"transB"
)).
at
<
bool
>
();
trans
_
b
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"transB"
)).
at
<
bool
>
();
}
std
::
vector
<
int64_t
>
perm
(
2
);
...
...
@@ -74,24 +74,28 @@ struct parse_gemm : op_parser<parse_gemm>
// swap the last two elements
std
::
swap
(
*
perm
.
rbegin
(),
*
(
perm
.
rbegin
()
+
1
));
auto
dot_type
=
A
->
get_shape
().
type
();
auto
dot_type
=
a_arg
->
get_shape
().
type
();
if
(
alpha
!=
1.0
f
)
{
auto
alpha_literal
=
info
.
add_literal
(
alpha
);
A
=
info
.
add_broadcastable_binary_op
(
"mul"
,
alpha_literal
,
A
);
a_arg
=
info
.
add_broadcastable_binary_op
(
"mul"
,
alpha_literal
,
a_arg
);
if
(
A
->
get_shape
().
type
()
!=
dot_type
)
if
(
a_arg
->
get_shape
().
type
()
!=
dot_type
)
{
A
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
A
);
a_arg
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
a_arg
);
}
}
A
=
(
transa
)
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
A
)
:
A
;
B
=
(
transb
)
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
1
])
:
args
[
1
];
a_arg
=
(
trans_a
)
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
a_arg
)
:
a_arg
;
b_arg
=
(
trans_b
)
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
1
])
:
args
[
1
];
auto
ret
=
info
.
add_instruction
(
make_op
(
"dot"
),
A
,
B
);
auto
ret
=
info
.
add_instruction
(
make_op
(
"dot"
),
a_arg
,
b_arg
);
if
(
args
.
size
()
==
3
)
{
...
...
@@ -104,24 +108,24 @@ struct parse_gemm : op_parser<parse_gemm>
}
if
(
not
float_equal
(
beta
,
0.0
f
)
and
args
[
2
]
->
get_shape
().
elements
()
>
0
)
{
auto
out_lens
=
A
->
get_shape
().
lens
();
out_lens
.
back
()
=
B
->
get_shape
().
lens
().
back
();
auto
C
=
args
[
2
];
auto
C
_lens
=
C
->
get_shape
().
lens
();
if
(
not
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
C
_lens
.
begin
(),
C
_lens
.
end
()))
auto
out_lens
=
a_arg
->
get_shape
().
lens
();
out_lens
.
back
()
=
b_arg
->
get_shape
().
lens
().
back
();
auto
c_arg
=
args
[
2
];
auto
c
_lens
=
c_arg
->
get_shape
().
lens
();
if
(
not
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
c
_lens
.
begin
(),
c
_lens
.
end
()))
{
C
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
args
[
2
]);
c_arg
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
args
[
2
]);
}
auto
beta_literal
=
info
.
add_literal
(
beta
);
auto
beta_
C
=
info
.
add_broadcastable_binary_op
(
"mul"
,
C
,
beta_literal
);
if
(
beta_
C
->
get_shape
().
type
()
!=
dot_type
)
auto
beta_
c
=
info
.
add_broadcastable_binary_op
(
"mul"
,
c_arg
,
beta_literal
);
if
(
beta_
c
->
get_shape
().
type
()
!=
dot_type
)
{
beta_
C
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
beta_
C
);
beta_
c
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
beta_
c
);
}
return
info
.
add_instruction
(
make_op
(
"add"
),
ret
,
beta_
C
);
return
info
.
add_instruction
(
make_op
(
"add"
),
ret
,
beta_
c
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment