Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b2106be7
Commit
b2106be7
authored
Mar 08, 2019
by
Shucai Xiao
Browse files
clang format
parent
007ea283
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
29 deletions
+26
-29
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+26
-29
No files found.
src/include/migraphx/operators.hpp
View file @
b2106be7
...
...
@@ -838,9 +838,9 @@ struct dot
if
(
b
.
empty
())
return
a
;
if
(
a
.
empty
())
if
(
a
.
empty
())
{
if
(
is_mutli_broadcast
)
if
(
is_mutli_broadcast
)
{
return
b
;
}
...
...
@@ -853,11 +853,10 @@ struct dot
auto
a_size
=
a
.
size
();
auto
b_size
=
b
.
size
();
if
(
is_mutli_broadcast
&&
b_size
>
a_size
)
if
(
is_mutli_broadcast
&&
b_size
>
a_size
)
{
MIGRAPHX_THROW
(
"DOT: C {"
+
to_string_range
(
b
)
+
"} is not broadcastable to A * b {"
+
to_string_range
(
a
)
+
"}"
);
MIGRAPHX_THROW
(
"DOT: C {"
+
to_string_range
(
b
)
+
"} is not broadcastable to A * b {"
+
to_string_range
(
a
)
+
"}"
);
}
auto
n_dim
=
std
::
min
(
a_size
,
b_size
);
...
...
@@ -880,11 +879,11 @@ struct dot
}
else
{
if
(
is_mutli_broadcast
)
if
(
is_mutli_broadcast
)
{
MIGRAPHX_THROW
(
"DOT : dimension mismatch, matrix A: {"
+
to_string_range
(
a
)
+
"}, and matrix B: {"
+
to_string_range
(
b
)
+
"} are not broadcastable"
);
MIGRAPHX_THROW
(
"DOT : dimension mismatch, matrix A: {"
+
to_string_range
(
a
)
+
"}, and matrix B: {"
+
to_string_range
(
b
)
+
"} are not broadcastable"
);
}
else
{
...
...
@@ -929,13 +928,13 @@ struct dot
bool
is_a_appended
=
false
;
bool
is_b_appended
=
false
;
if
(
a_lens
.
size
()
==
1
)
if
(
a_lens
.
size
()
==
1
)
{
a_lens
.
insert
(
a_lens
.
begin
(),
1
);
is_a_appended
=
true
;
}
if
(
b_lens
.
size
()
==
1
)
if
(
b_lens
.
size
()
==
1
)
{
b_lens
.
push_back
(
1
);
is_b_appended
=
true
;
...
...
@@ -943,11 +942,10 @@ struct dot
std
::
size_t
dim_0
=
a_lens
.
size
()
-
1
;
std
::
size_t
dim_1
=
b_lens
.
size
()
-
2
;
if
(
a_lens
[
dim_0
]
!=
b_lens
[
dim_1
])
if
(
a_lens
[
dim_0
]
!=
b_lens
[
dim_1
])
{
MIGRAPHX_THROW
(
"DOT : dimension mismatch, operand A: {"
+
to_string_range
(
a
.
lens
())
+
"}, cannot multiply operand B: {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
MIGRAPHX_THROW
(
"DOT : dimension mismatch, operand A: {"
+
to_string_range
(
a
.
lens
())
+
"}, cannot multiply operand B: {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
}
// remove the matrix dims, do multi_broadcast of the shape of the batch
...
...
@@ -964,34 +962,33 @@ struct dot
out_lens
.
push_back
(
out_n
);
// remove the prepended 1, if a is a vector
if
(
is_a_appended
)
if
(
is_a_appended
)
{
out_lens
.
erase
(
out_lens
.
begin
()
+
out_lens
.
size
()
-
2
);
}
// remove the appended 1, if b is a vector
if
(
is_b_appended
)
if
(
is_b_appended
)
{
out_lens
.
pop_back
();
}
// c is unibroadcastable to A * B
if
(
inputs
.
size
()
==
3
)
{
// same type as A and B
check_shapes
{{
inputs
[
0
],
inputs
[
2
]},
*
this
}.
has
(
2
).
same_type
();
if
(
out_lens
.
empty
()
&&
(
!
inputs
[
2
].
scalar
()))
if
(
out_lens
.
empty
()
&&
(
!
inputs
[
2
].
scalar
()))
{
MIGRAPHX_THROW
(
"DOT: C is not broadcastable to A*B (scalar)"
);
}
//check c is broadcastable to A * B
//
check c is broadcastable to A * B
auto
c_lens
=
inputs
[
2
].
lens
();
shape_broadcast
(
out_lens
,
c_lens
,
false
);
}
if
(
out_lens
.
empty
())
if
(
out_lens
.
empty
())
{
return
{
t
};
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment