Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8a45e79f
"src/targets/gpu/vscode:/vscode.git/clone" did not exist on "3c301efa15c8afc2b7b3153f9c5979ed13effa8d"
Commit
8a45e79f
authored
Apr 03, 2019
by
Shucai Xiao
Browse files
clang format
parent
77212cc1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
7 deletions
+8
-7
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+5
-5
src/targets/cpu/gemm.cpp
src/targets/cpu/gemm.cpp
+3
-2
No files found.
src/include/migraphx/operators.hpp
View file @
8a45e79f
...
...
@@ -839,7 +839,8 @@ struct dot
// according to the specification of the numpy.matmul()
// inputs with the shape dims more than 2 are acceptable
// as long as dim values are the same in the two inputs
if
(
!
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
if
(
!
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
{
MIGRAPHX_THROW
(
"DOT: dim values mismatch"
);
}
...
...
@@ -854,11 +855,10 @@ struct dot
auto
out_lens
=
a
.
lens
();
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
if
(
inputs
.
size
()
==
3
&&
out_lens
!=
inputs
.
at
(
2
).
lens
())
if
(
inputs
.
size
()
==
3
&&
out_lens
!=
inputs
.
at
(
2
).
lens
())
{
MIGRAPHX_THROW
(
"DOT: dimension mismatch, operand C: {"
+
to_string_range
(
c_lens
)
+
"}, cannot add to operand A * B: {"
+
to_string_range
(
out_lens
)
+
"}"
);
"}, cannot add to operand A * B: {"
+
to_string_range
(
out_lens
)
+
"}"
);
}
return
{
t
,
out_lens
};
...
...
src/targets/cpu/gemm.cpp
View file @
8a45e79f
...
...
@@ -101,7 +101,8 @@ void migemm_impl(
{
auto
lens
=
amat
.
get_shape
().
lens
();
bool
batch_mul
=
std
::
accumulate
(
lens
.
rbegin
()
+
2
,
lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
())
==
1
;
std
::
accumulate
(
lens
.
rbegin
()
+
2
,
lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
())
==
1
;
if
(
batch_mul
)
{
migemm_impl
(
cmat
,
amat
,
bmat
,
alpha
,
beta
,
is_fast_gemm_type
<
T
>
{});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment