Commit a350dcc2 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added/fixed contiguous test and made transpose test more complete

parent 796d1802
......@@ -398,20 +398,37 @@ void transpose_test()
std::vector<float> data(12);
std::iota(data.begin(), data.end(), 0);
rtg::program p;
auto l = p.add_literal(rtg::literal{a_shape, data});
std::vector<int64_t> perm = {0,3,1,2};
p.add_instruction(rtg::transpose{perm}, l);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
{
rtg::program p;
auto l = p.add_literal(rtg::literal{a_shape, data});
std::vector<int64_t> perm = {0,3,1,2};
p.add_instruction(rtg::transpose{perm}, l);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(12);
result.visit([&] (auto output){
std::vector<size_t> new_lens = {1,3,2,2};
std::vector<size_t> new_strides = {12,1,6,3};
EXPECT(bool{output.get_shape().lens() == new_lens});
EXPECT(bool{output.get_shape().strides() == new_strides});
});
result.visit([&] (auto output){
std::vector<size_t> new_lens = {1,3,2,2};
std::vector<size_t> new_strides = {12,1,6,3};
EXPECT(bool{output.get_shape().lens() == new_lens});
EXPECT(bool{output.get_shape().strides() == new_strides});
});
}
{
rtg::program p;
auto l = p.add_literal(rtg::literal{a_shape, data});
std::vector<int64_t> perm = {0,3,1,2};
auto result = p.add_instruction(rtg::transpose{perm}, l);
p.add_instruction(rtg::contiguous{}, result);
p.compile(rtg::cpu::cpu_target{});
auto result2 = p.eval({});
std::vector<float> results_vector(12);
result2.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<size_t> new_lens = {1, 3, 2, 2};
std::vector<size_t> new_strides = {12, 1, 6, 3};
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(test::verify_range(results_vector, gold));
}
}
void contiguous_test() {
......@@ -426,11 +443,11 @@ void contiguous_test() {
auto result = p.eval({});
std::vector<float> results_vector(12);
result.visit([&] (auto output){
std::vector<size_t> new_lens = {1, 3, 2, 2};
std::vector<size_t> new_strides = {12, 1, 6, 3};
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
});
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<size_t> new_lens = {1, 3, 2, 2};
std::vector<size_t> new_strides = {12, 1, 6, 3};
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(test::verify_range(results_vector, gold));
}
int main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment