Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7e522c7e
"tests/vscode:/vscode.git/clone" did not exist on "1b91856d0eee7b6fb58340e9b54ea2c3d5424311"
Commit
7e522c7e
authored
Jun 21, 2018
by
Paul
Browse files
Formatting
parent
33f53196
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
12 deletions
+13
-12
src/include/rtg/operators.hpp
src/include/rtg/operators.hpp
+9
-8
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+4
-4
No files found.
src/include/rtg/operators.hpp
View file @
7e522c7e
...
@@ -426,12 +426,13 @@ struct broadcast
...
@@ -426,12 +426,13 @@ struct broadcast
std
::
string
name
()
const
{
return
"broadcast"
;
}
std
::
string
name
()
const
{
return
"broadcast"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
auto
t
=
inputs
.
at
(
0
).
type
();
auto
t
=
inputs
.
at
(
0
).
type
();
auto
result
=
inputs
.
at
(
0
);
auto
result
=
inputs
.
at
(
0
);
auto
input
=
inputs
.
at
(
1
);
auto
input
=
inputs
.
at
(
1
);
std
::
vector
<
size_t
>
bcast_strides
(
result
.
lens
().
size
(),
0
);
std
::
vector
<
size_t
>
bcast_strides
(
result
.
lens
().
size
(),
0
);
if
(
std
::
all_of
(
result
.
lens
().
cbegin
(),
result
.
lens
().
cend
(),
[
&
](
auto
x
)
{
return
x
==
1
;
}))
if
(
std
::
all_of
(
result
.
lens
().
cbegin
(),
result
.
lens
().
cend
(),
[
&
](
auto
x
)
{
return
x
==
1
;
}))
{
{
if
(
axis
!=
0
)
if
(
axis
!=
0
)
RTG_THROW
(
"when broadcasting tensor of size 1, axis should be 0"
);
RTG_THROW
(
"when broadcasting tensor of size 1, axis should be 0"
);
...
@@ -439,10 +440,10 @@ struct broadcast
...
@@ -439,10 +440,10 @@ struct broadcast
}
}
else
else
{
{
assert
(
result
.
lens
().
size
()
-
axis
>=
input
.
lens
().
size
());
assert
(
result
.
lens
().
size
()
-
axis
>=
input
.
lens
().
size
());
if
(
!
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
result
.
lens
().
begin
()
+
axis
))
if
(
!
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
result
.
lens
().
begin
()
+
axis
))
RTG_THROW
(
"when broadcasting success sizes must match"
);
RTG_THROW
(
"when broadcasting success sizes must match"
);
std
::
copy
(
input
.
strides
().
begin
(),
input
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
std
::
copy
(
input
.
strides
().
begin
(),
input
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
return
{
t
,
result
.
lens
(),
std
::
move
(
bcast_strides
)};
return
{
t
,
result
.
lens
(),
std
::
move
(
bcast_strides
)};
}
}
}
}
...
...
test/cpu_ops_test.cpp
View file @
7e522c7e
...
@@ -91,10 +91,10 @@ void broadcast_test()
...
@@ -91,10 +91,10 @@ void broadcast_test()
p
.
compile
(
rtg
::
cpu
::
cpu_target
{});
p
.
compile
(
rtg
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
auto
result
=
p
.
eval
({});
auto
output
=
result
.
get
<
int32_t
>
();
auto
output
=
result
.
get
<
int32_t
>
();
EXPECT
(
output
(
0
,
0
)
==
-
2
);
EXPECT
(
output
(
0
,
0
)
==
-
2
);
EXPECT
(
output
(
0
,
1
)
==
-
2
);
EXPECT
(
output
(
0
,
1
)
==
-
2
);
EXPECT
(
output
(
1
,
0
)
==
-
3
);
EXPECT
(
output
(
1
,
0
)
==
-
3
);
EXPECT
(
output
(
1
,
1
)
==
-
3
);
EXPECT
(
output
(
1
,
1
)
==
-
3
);
}
}
void
add_broadcast_test
()
void
add_broadcast_test
()
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment