Commit 92051ab8 authored by Scott Thornton's avatar Scott Thornton
Browse files

Fixed broadcast and added tests

parent 21e88916
......@@ -431,8 +431,8 @@ struct broadcast
auto shape1 = inputs.at(1);
auto shape0_lens = shape0.lens();
auto shape1_lens = shape1.lens();
const auto& shape0_strides = shape0.lens();
auto shape1_strides = shape1.lens();
const auto& shape0_strides = shape0.strides();
auto shape1_strides = shape1.strides();
if(std::all_of(shape0_lens.cbegin(), shape1_lens.cend(), [&](auto x) { return x == 1; }))
{
if(axis != 0)
......
......@@ -6,33 +6,6 @@
#include "test.hpp"
#include "verify.hpp"
void fred()
{
size_t axis = 1;
rtg::shape shape0{rtg::shape::float_type, {2, 4, 3, 4}};
rtg::shape shape1{rtg::shape::float_type, {4, 3}};
std::vector<size_t> shape0_lens = shape0.lens();
std::vector<size_t> shape1_lens = shape1.lens();
const std::vector<size_t>& shape0_strides = shape0.strides();
std::vector<size_t> shape1_strides = shape1.strides();
for(size_t i = 0; i < shape1.lens().size(); i++)
{
assert(shape0_lens[i + axis] == shape1_lens[i]);
}
std::vector<size_t> bcast_shape_lens = shape0_lens;
std::vector<size_t> bcast_shape_strides(bcast_shape_lens.size(), 0);
for(size_t i = 0; i < shape1_strides.size(); i++)
{
bcast_shape_strides[i + axis] = shape1_strides[i];
}
for(auto x : bcast_shape_lens)
std::cout << x << " ";
std::cout << "\n";
for(auto x : bcast_shape_strides)
std::cout << x << " ";
std::cout << "\n";
}
void exp_test()
{
rtg::program p;
......@@ -104,6 +77,48 @@ void add_test()
EXPECT(test::verify_range(results_vector, gold));
}
void broadcast_test()
{
rtg::program p;
rtg::shape a_shape{rtg::shape::int32_type, {2,2}};
std::vector<int32_t> a_data{0,0,0,0};
rtg::shape b_shape{rtg::shape::int32_type, {2}};
std::vector<int32_t> b_data{-2,-3};
uint64_t axis = 0;
auto l1 = p.add_literal(rtg::literal{a_shape, a_data});
auto l2 = p.add_literal(rtg::literal{b_shape, b_data});
p.add_instruction(rtg::broadcast{axis}, l1, l2);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<int32_t> results_vector(4);
// result.visit([&](auto output) {
// EXPECT(output(0,0) == -2);
// EXPECT(output(0,1) == -2);
// EXPECT(output(1,0) == -3);
// EXPECT(output(1,1) == -3);
// });
}
void add_broadcast_test()
{
rtg::program p;
rtg::shape a_shape{rtg::shape::float_type, {2,2,3}};
std::vector<float> a_data{0,1,2,3,4,5,6,7,8,9,10,11};
rtg::shape b_shape{rtg::shape::float_type, {2,2}};
std::vector<float> b_data{0,-1,-2,-3};
uint64_t axis = 0;
auto l1 = p.add_literal(rtg::literal{a_shape, a_data});
auto l2 = p.add_literal(rtg::literal{b_shape, b_data});
auto l3 = p.add_instruction(rtg::broadcast{axis}, l1, l2);
p.add_instruction(rtg::add{}, l1, l3);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0,1,2,2,3,4,4,5,6,6,7,8};
EXPECT(test::verify_range(results_vector, gold));
}
void sub_test()
{
rtg::program p;
......@@ -189,7 +204,6 @@ void reshape_test()
}
}
// std::cout << std::abs(results_vector[i]-gold[i]) << std::endl;
void gemm_test()
{
rtg::program p;
......@@ -538,12 +552,13 @@ void contiguous_test()
int main()
{
fred();
exp_test();
sin_test();
cos_test();
tan_test();
add_test();
broadcast_test();
add_broadcast_test();
sub_test();
mul_test();
gemm_test();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment