Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b2106be7
"docs/zh_cn/vscode:/vscode.git/clone" did not exist on "f45977008a52baaf97640a0e9b2bbe5ea1c4be34"
Commit
b2106be7
authored
Mar 08, 2019
by
Shucai Xiao
Browse files
clang format
parent
007ea283
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
29 deletions
+26
-29
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+26
-29
No files found.
src/include/migraphx/operators.hpp
View file @
b2106be7
...
...
@@ -838,9 +838,9 @@ struct dot
if
(
b
.
empty
())
return
a
;
if
(
a
.
empty
())
if
(
a
.
empty
())
{
if
(
is_mutli_broadcast
)
if
(
is_mutli_broadcast
)
{
return
b
;
}
...
...
@@ -853,14 +853,13 @@ struct dot
auto
a_size
=
a
.
size
();
auto
b_size
=
b
.
size
();
if
(
is_mutli_broadcast
&&
b_size
>
a_size
)
if
(
is_mutli_broadcast
&&
b_size
>
a_size
)
{
MIGRAPHX_THROW
(
"DOT: C {"
+
to_string_range
(
b
)
+
"} is not broadcastable to A * b {"
+
to_string_range
(
a
)
+
"}"
);
MIGRAPHX_THROW
(
"DOT: C {"
+
to_string_range
(
b
)
+
"} is not broadcastable to A * b {"
+
to_string_range
(
a
)
+
"}"
);
}
auto
n_dim
=
std
::
min
(
a_size
,
b_size
);
auto
n_dim
=
std
::
min
(
a_size
,
b_size
);
std
::
vector
<
std
::
size_t
>
out_lens
(
std
::
max
(
a_size
,
b_size
));
for
(
std
::
size_t
i
=
0
;
i
<
n_dim
;
++
i
)
{
...
...
@@ -872,25 +871,25 @@ struct dot
{
out_lens
[
i
]
=
a
[
a_size
-
1
-
i
];
}
else
{
else
{
if
(
a
[
a_size
-
1
-
i
]
==
1
&&
is_mutli_broadcast
)
{
out_lens
[
i
]
=
b
[
b_size
-
1
-
i
];
}
else
{
if
(
is_mutli_broadcast
)
if
(
is_mutli_broadcast
)
{
MIGRAPHX_THROW
(
"DOT : dimension mismatch, matrix A: {"
+
to_string_range
(
a
)
+
"}, and matrix B: {"
+
to_string_range
(
b
)
+
"} are not broadcastable"
);
MIGRAPHX_THROW
(
"DOT : dimension mismatch, matrix A: {"
+
to_string_range
(
a
)
+
"}, and matrix B: {"
+
to_string_range
(
b
)
+
"} are not broadcastable"
);
}
else
{
MIGRAPHX_THROW
(
"DOT: C {"
+
to_string_range
(
b
)
+
"} is not broadcastable to A * b {"
+
to_string_range
(
a
)
+
"}"
);
"} is not broadcastable to A * b {"
+
to_string_range
(
a
)
+
"}"
);
}
}
}
...
...
@@ -924,18 +923,18 @@ struct dot
MIGRAPHX_THROW
(
"DOT: scalar operands are not allowed, use op::mul{} instead"
);
}
auto
a_lens
=
a
.
lens
();
auto
b_lens
=
b
.
lens
();
auto
a_lens
=
a
.
lens
();
auto
b_lens
=
b
.
lens
();
bool
is_a_appended
=
false
;
bool
is_b_appended
=
false
;
if
(
a_lens
.
size
()
==
1
)
if
(
a_lens
.
size
()
==
1
)
{
a_lens
.
insert
(
a_lens
.
begin
(),
1
);
is_a_appended
=
true
;
}
if
(
b_lens
.
size
()
==
1
)
if
(
b_lens
.
size
()
==
1
)
{
b_lens
.
push_back
(
1
);
is_b_appended
=
true
;
...
...
@@ -943,11 +942,10 @@ struct dot
std
::
size_t
dim_0
=
a_lens
.
size
()
-
1
;
std
::
size_t
dim_1
=
b_lens
.
size
()
-
2
;
if
(
a_lens
[
dim_0
]
!=
b_lens
[
dim_1
])
if
(
a_lens
[
dim_0
]
!=
b_lens
[
dim_1
])
{
MIGRAPHX_THROW
(
"DOT : dimension mismatch, operand A: {"
+
to_string_range
(
a
.
lens
())
+
"}, cannot multiply operand B: {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
MIGRAPHX_THROW
(
"DOT : dimension mismatch, operand A: {"
+
to_string_range
(
a
.
lens
())
+
"}, cannot multiply operand B: {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
}
// remove the matrix dims, do multi_broadcast of the shape of the batch
...
...
@@ -964,34 +962,33 @@ struct dot
out_lens
.
push_back
(
out_n
);
// remove the prepended 1, if a is a vector
if
(
is_a_appended
)
if
(
is_a_appended
)
{
out_lens
.
erase
(
out_lens
.
begin
()
+
out_lens
.
size
()
-
2
);
}
// remove the appended 1, if b is a vector
if
(
is_b_appended
)
if
(
is_b_appended
)
{
out_lens
.
pop_back
();
}
// c is unibroadcastable to A * B
if
(
inputs
.
size
()
==
3
)
{
// same type as A and B
check_shapes
{{
inputs
[
0
],
inputs
[
2
]},
*
this
}.
has
(
2
).
same_type
();
if
(
out_lens
.
empty
()
&&
(
!
inputs
[
2
].
scalar
()))
if
(
out_lens
.
empty
()
&&
(
!
inputs
[
2
].
scalar
()))
{
MIGRAPHX_THROW
(
"DOT: C is not broadcastable to A*B (scalar)"
);
}
//check c is broadcastable to A * B
//
check c is broadcastable to A * B
auto
c_lens
=
inputs
[
2
].
lens
();
shape_broadcast
(
out_lens
,
c_lens
,
false
);
}
if
(
out_lens
.
empty
())
if
(
out_lens
.
empty
())
{
return
{
t
};
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment