Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
92051ab8
"test/ref_ops_test.cpp" did not exist on "5d4fcb59a56dbf647c9ac0a86a62dec1f039d993"
Commit
92051ab8
authored
Jun 20, 2018
by
Scott Thornton
Browse files
Fixed broadcast and added tests
parent
21e88916
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
46 additions
and
31 deletions
+46
-31
src/include/rtg/operators.hpp
src/include/rtg/operators.hpp
+2
-2
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+44
-29
No files found.
src/include/rtg/operators.hpp
View file @
92051ab8
...
...
@@ -431,8 +431,8 @@ struct broadcast
auto
shape1
=
inputs
.
at
(
1
);
auto
shape0_lens
=
shape0
.
lens
();
auto
shape1_lens
=
shape1
.
lens
();
const
auto
&
shape0_strides
=
shape0
.
len
s
();
auto
shape1_strides
=
shape1
.
len
s
();
const
auto
&
shape0_strides
=
shape0
.
stride
s
();
auto
shape1_strides
=
shape1
.
stride
s
();
if
(
std
::
all_of
(
shape0_lens
.
cbegin
(),
shape1_lens
.
cend
(),
[
&
](
auto
x
)
{
return
x
==
1
;
}))
{
if
(
axis
!=
0
)
...
...
test/cpu_ops_test.cpp
View file @
92051ab8
...
...
@@ -6,33 +6,6 @@
#include "test.hpp"
#include "verify.hpp"
void
fred
()
{
size_t
axis
=
1
;
rtg
::
shape
shape0
{
rtg
::
shape
::
float_type
,
{
2
,
4
,
3
,
4
}};
rtg
::
shape
shape1
{
rtg
::
shape
::
float_type
,
{
4
,
3
}};
std
::
vector
<
size_t
>
shape0_lens
=
shape0
.
lens
();
std
::
vector
<
size_t
>
shape1_lens
=
shape1
.
lens
();
const
std
::
vector
<
size_t
>&
shape0_strides
=
shape0
.
strides
();
std
::
vector
<
size_t
>
shape1_strides
=
shape1
.
strides
();
for
(
size_t
i
=
0
;
i
<
shape1
.
lens
().
size
();
i
++
)
{
assert
(
shape0_lens
[
i
+
axis
]
==
shape1_lens
[
i
]);
}
std
::
vector
<
size_t
>
bcast_shape_lens
=
shape0_lens
;
std
::
vector
<
size_t
>
bcast_shape_strides
(
bcast_shape_lens
.
size
(),
0
);
for
(
size_t
i
=
0
;
i
<
shape1_strides
.
size
();
i
++
)
{
bcast_shape_strides
[
i
+
axis
]
=
shape1_strides
[
i
];
}
for
(
auto
x
:
bcast_shape_lens
)
std
::
cout
<<
x
<<
" "
;
std
::
cout
<<
"
\n
"
;
for
(
auto
x
:
bcast_shape_strides
)
std
::
cout
<<
x
<<
" "
;
std
::
cout
<<
"
\n
"
;
}
void
exp_test
()
{
rtg
::
program
p
;
...
...
@@ -104,6 +77,48 @@ void add_test()
EXPECT
(
test
::
verify_range
(
results_vector
,
gold
));
}
void
broadcast_test
()
{
rtg
::
program
p
;
rtg
::
shape
a_shape
{
rtg
::
shape
::
int32_type
,
{
2
,
2
}};
std
::
vector
<
int32_t
>
a_data
{
0
,
0
,
0
,
0
};
rtg
::
shape
b_shape
{
rtg
::
shape
::
int32_type
,
{
2
}};
std
::
vector
<
int32_t
>
b_data
{
-
2
,
-
3
};
uint64_t
axis
=
0
;
auto
l1
=
p
.
add_literal
(
rtg
::
literal
{
a_shape
,
a_data
});
auto
l2
=
p
.
add_literal
(
rtg
::
literal
{
b_shape
,
b_data
});
p
.
add_instruction
(
rtg
::
broadcast
{
axis
},
l1
,
l2
);
p
.
compile
(
rtg
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
int32_t
>
results_vector
(
4
);
// result.visit([&](auto output) {
// EXPECT(output(0,0) == -2);
// EXPECT(output(0,1) == -2);
// EXPECT(output(1,0) == -3);
// EXPECT(output(1,1) == -3);
// });
}
void
add_broadcast_test
()
{
rtg
::
program
p
;
rtg
::
shape
a_shape
{
rtg
::
shape
::
float_type
,
{
2
,
2
,
3
}};
std
::
vector
<
float
>
a_data
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
};
rtg
::
shape
b_shape
{
rtg
::
shape
::
float_type
,
{
2
,
2
}};
std
::
vector
<
float
>
b_data
{
0
,
-
1
,
-
2
,
-
3
};
uint64_t
axis
=
0
;
auto
l1
=
p
.
add_literal
(
rtg
::
literal
{
a_shape
,
a_data
});
auto
l2
=
p
.
add_literal
(
rtg
::
literal
{
b_shape
,
b_data
});
auto
l3
=
p
.
add_instruction
(
rtg
::
broadcast
{
axis
},
l1
,
l2
);
p
.
add_instruction
(
rtg
::
add
{},
l1
,
l3
);
p
.
compile
(
rtg
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
results_vector
(
12
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
0
,
1
,
2
,
2
,
3
,
4
,
4
,
5
,
6
,
6
,
7
,
8
};
EXPECT
(
test
::
verify_range
(
results_vector
,
gold
));
}
void
sub_test
()
{
rtg
::
program
p
;
...
...
@@ -189,7 +204,6 @@ void reshape_test()
}
}
// std::cout << std::abs(results_vector[i]-gold[i]) << std::endl;
void
gemm_test
()
{
rtg
::
program
p
;
...
...
@@ -538,12 +552,13 @@ void contiguous_test()
int
main
()
{
fred
();
exp_test
();
sin_test
();
cos_test
();
tan_test
();
add_test
();
broadcast_test
();
add_broadcast_test
();
sub_test
();
mul_test
();
gemm_test
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment