Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7e522c7e
Commit
7e522c7e
authored
Jun 21, 2018
by
Paul
Browse files
Formatting
parent
33f53196
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
12 deletions
+13
-12
src/include/rtg/operators.hpp
src/include/rtg/operators.hpp
+9
-8
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+4
-4
No files found.
src/include/rtg/operators.hpp
View file @
7e522c7e
...
...
@@ -431,7 +431,8 @@ struct broadcast
auto
input
=
inputs
.
at
(
1
);
std
::
vector
<
size_t
>
bcast_strides
(
result
.
lens
().
size
(),
0
);
if
(
std
::
all_of
(
result
.
lens
().
cbegin
(),
result
.
lens
().
cend
(),
[
&
](
auto
x
)
{
return
x
==
1
;
}))
if
(
std
::
all_of
(
result
.
lens
().
cbegin
(),
result
.
lens
().
cend
(),
[
&
](
auto
x
)
{
return
x
==
1
;
}))
{
if
(
axis
!=
0
)
RTG_THROW
(
"when broadcasting tensor of size 1, axis should be 0"
);
...
...
@@ -439,10 +440,10 @@ struct broadcast
}
else
{
assert
(
result
.
lens
().
size
()
-
axis
>=
input
.
lens
().
size
());
if
(
!
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
result
.
lens
().
begin
()
+
axis
))
assert
(
result
.
lens
().
size
()
-
axis
>=
input
.
lens
().
size
());
if
(
!
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
result
.
lens
().
begin
()
+
axis
))
RTG_THROW
(
"when broadcasting success sizes must match"
);
std
::
copy
(
input
.
strides
().
begin
(),
input
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
std
::
copy
(
input
.
strides
().
begin
(),
input
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
return
{
t
,
result
.
lens
(),
std
::
move
(
bcast_strides
)};
}
}
...
...
test/cpu_ops_test.cpp
View file @
7e522c7e
...
...
@@ -91,10 +91,10 @@ void broadcast_test()
p
.
compile
(
rtg
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
auto
output
=
result
.
get
<
int32_t
>
();
EXPECT
(
output
(
0
,
0
)
==
-
2
);
EXPECT
(
output
(
0
,
1
)
==
-
2
);
EXPECT
(
output
(
1
,
0
)
==
-
3
);
EXPECT
(
output
(
1
,
1
)
==
-
3
);
EXPECT
(
output
(
0
,
0
)
==
-
2
);
EXPECT
(
output
(
0
,
1
)
==
-
2
);
EXPECT
(
output
(
1
,
0
)
==
-
3
);
EXPECT
(
output
(
1
,
1
)
==
-
3
);
}
void
add_broadcast_test
()
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment