#include "utils_test.h"

#include <cstddef>
#include <cstring>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

/// Advances two byte offsets to the next element of an N-d iteration in
/// row-major order (the last dimension varies fastest).
///
/// Both offsets walk the same logical index space `shape`, but each uses its
/// own strides and element size, so two differently laid-out buffers can be
/// visited element-by-element in lockstep.
///
/// @param offset_1    byte offset into the first buffer; updated in place
/// @param strides_1   per-dimension strides of the first buffer, in elements
/// @param data_size_1 element size of the first buffer, in bytes
/// @param offset_2    byte offset into the second buffer; updated in place
/// @param strides_2   per-dimension strides of the second buffer, in elements
/// @param data_size_2 element size of the second buffer, in bytes
/// @param counter     current multi-index (one entry per dimension); updated in place
/// @param shape       extent of each dimension
void incrementOffset(ptrdiff_t &offset_1, const std::vector<ptrdiff_t> &strides_1, size_t data_size_1,
                     ptrdiff_t &offset_2, const std::vector<ptrdiff_t> &strides_2, size_t data_size_2,
                     std::vector<size_t> &counter, const std::vector<size_t> &shape) {
    // Keep all offset arithmetic signed: multiplying a ptrdiff_t stride by a
    // size_t would promote the product to unsigned and wrap for negative strides.
    const auto elem_1 = static_cast<ptrdiff_t>(data_size_1);
    const auto elem_2 = static_cast<ptrdiff_t>(data_size_2);
    for (ptrdiff_t d = static_cast<ptrdiff_t>(shape.size()) - 1; d >= 0; d--) {
        counter[d] += 1;
        offset_1 += strides_1[d] * elem_1;
        offset_2 += strides_2[d] * elem_2;
        if (counter[d] < shape[d]) {
            break; // no carry into the next-outer dimension
        }
        // Dimension d wrapped around: reset it, rewind its whole contribution
        // to both offsets, and carry into dimension d - 1.
        const auto extent = static_cast<ptrdiff_t>(shape[d]);
        counter[d] = 0;
        offset_1 -= extent * strides_1[d] * elem_1;
        offset_2 -= extent * strides_2[d] * elem_2;
    }
}

/// Compares two strided buffers of element type T, element by element.
///
/// Elements are compared bit-for-bit with memcmp, so e.g. -0.0f vs 0.0f is a
/// mismatch. Each mismatch is reported on std::cerr with its linear index.
///
/// @tparam T        element type (its size is the comparison width; values are printed with operator<<)
/// @param a         base pointer of the first buffer
/// @param b         base pointer of the second buffer
/// @param shape     logical shape shared by both buffers
/// @param strides_a per-dimension strides of `a`, in elements
/// @param strides_b per-dimension strides of `b`, in elements
/// @return number of mismatching elements (0 means equal)
template <typename T>
size_t check_equal(
    const void *a,
    const void *b,
    const std::vector<size_t> &shape,
    const std::vector<ptrdiff_t> &strides_a,
    const std::vector<ptrdiff_t> &strides_b) {
    const size_t element_size = sizeof(T);
    std::vector<size_t> counter(shape.size(), 0);
    ptrdiff_t offset_a = 0;
    ptrdiff_t offset_b = 0;
    const size_t numel = std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies<size_t>());
    size_t fails = 0;
    for (size_t i = 0; i < numel; i++) {
        const auto *ptr_a = reinterpret_cast<const T *>(static_cast<const char *>(a) + offset_a);
        const auto *ptr_b = reinterpret_cast<const T *>(static_cast<const char *>(b) + offset_b);
        if (std::memcmp(ptr_a, ptr_b, element_size) != 0) {
            std::cerr << "Error at " << i << ": " << *ptr_a << " vs " << *ptr_b << '\n';
            fails++;
        }
        incrementOffset(offset_a, strides_a, element_size,
                        offset_b, strides_b, element_size,
                        counter, shape);
    }
    return fails;
}

/// Rearranges a 3x5 float matrix from strides {5, 1} into strides {1, 3}
/// with utils::rearrange, then checks every logical element matches.
///
/// @return 0 on success, 1 on failure
int test_transpose_2d() {
    std::vector<size_t> shape = {3, 5};
    std::vector<ptrdiff_t> strides_a = {5, 1}; // source layout
    std::vector<ptrdiff_t> strides_b = {1, 3}; // destination layout
    const size_t numel = std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies<size_t>());

    std::vector<float> a(numel);
    // Pre-fill the destination with a value that never appears in `a`, so any
    // element that rearrange fails to write shows up as a mismatch.
    std::vector<float> b(numel, -1.0f);
    // Source values must be distinct, otherwise a wrong permutation would still
    // compare equal. Note that integer `i / numel` would be 0 for every i.
    for (size_t i = 0; i < numel; i++) {
        a[i] = static_cast<float>(i);
    }

    // NOTE(review): argument order assumed to be
    // (dst, src, shape, dst_strides, src_strides, ndim, element_size) — confirm in utils.
    utils::rearrange(b.data(), a.data(), shape.data(), strides_b.data(), strides_a.data(),
                     shape.size(), sizeof(float));

    if (check_equal<float>(a.data(), b.data(), shape, strides_a, strides_b) != 0) {
        std::cerr << "test_transpose_2d failed" << '\n';
        return 1;
    }
    std::cout << "test_transpose_2d passed" << std::endl;
    return 0;
}

/// Runs all rearrange tests in this file.
///
/// @return 0 if every test passed, non-zero otherwise
int test_rearrange() {
    return test_transpose_2d();
}