Commit 4286f502 authored by PanZezhong's avatar PanZezhong
Browse files

issue/92 修改rearrange,添加单测

parent 309878f0
#include "utils_test.h"
int main(int argc, char *argv[]) {
int failed = 0;
failed += test_rearrange();
return failed;
}
#include "utils_test.h"
#include <cstring>
#include <iostream>
#include <numeric>
#include <vector>
void incrementOffset(ptrdiff_t &offset_1, const std::vector<ptrdiff_t> &strides_1, size_t data_size_1,
ptrdiff_t &offset_2, const std::vector<ptrdiff_t> &strides_2, size_t data_size_2,
std::vector<size_t> &counter, const std::vector<size_t> &shape) {
for (ptrdiff_t d = shape.size() - 1; d >= 0; d--) {
counter[d] += 1;
offset_1 += strides_1[d] * data_size_1;
offset_2 += strides_2[d] * data_size_2;
if (counter[d] < shape[d]) {
break;
}
counter[d] = 0;
offset_1 -= shape[d] * strides_1[d] * data_size_1;
offset_2 -= shape[d] * strides_2[d] * data_size_2;
}
}
template <typename T>
size_t check_equal(
const void *a,
const void *b,
const std::vector<size_t> &shape,
const std::vector<ptrdiff_t> &strides_a,
const std::vector<ptrdiff_t> &strides_b) {
auto element_size = sizeof(T);
std::vector<size_t> counter(shape.size(), 0);
ptrdiff_t offset_a = 0;
ptrdiff_t offset_b = 0;
size_t numel = std::accumulate(shape.begin(), shape.end(), (size_t)1, std::multiplies<size_t>());
size_t fails = 0;
for (size_t i = 0; i < numel; i++) {
const T *ptr_a = reinterpret_cast<const T *>((const char *)a + offset_a);
const T *ptr_b = reinterpret_cast<const T *>((const char *)b + offset_b);
if (memcmp(ptr_a, ptr_b, element_size) != 0) {
std::cerr << "Error at " << i << ": " << *ptr_a << " vs " << *ptr_b << std::endl;
fails++;
}
incrementOffset(offset_a, strides_a, element_size, offset_b, strides_b, element_size, counter, shape);
}
return fails;
}
int test_transpose_2d() {
std::vector<size_t> shape = {3, 5};
std::vector<ptrdiff_t> strides_a = {5, 1};
std::vector<ptrdiff_t> strides_b = {1, 3};
auto numel = std::accumulate(shape.begin(), shape.end(), (size_t)1, std::multiplies<size_t>());
std::vector<float> a(numel);
std::vector<float> b(numel);
for (size_t i = 0; i < numel; i++) {
a[i] = i / numel;
}
utils::rearrange(b.data(), a.data(), shape.data(), strides_b.data(), strides_a.data(), 2, sizeof(float));
if (check_equal<float>(a.data(), b.data(), shape, strides_a, strides_b)) {
return 1;
} else {
std::cout << "test_transpose_2d passed" << std::endl;
}
return 0;
}
int test_rearrange() {
return test_transpose_2d();
}
#ifndef __INFINIUTILS_TEST_H__
#define __INFINIUTILS_TEST_H__
#include "../utils.h"
int test_rearrange();
#endif
......@@ -98,15 +98,15 @@ void RearrangeMeta::launch(void *dst_, const void *src_) const {
if (count_ == 1) {
std::memcpy(dst_, src_, unit_);
} else {
for (size_t i = 0; i < idx_strides_[0]; ++i) {
for (size_t i = 0; i < count_; ++i) {
auto dst = reinterpret_cast<char *>(dst_);
auto src = reinterpret_cast<const char *>(src_);
auto rem = i;
for (size_t j = 0; j < ndim_; ++j) {
auto k = rem / idx_strides_[j + 1];
auto k = rem / idx_strides_[j];
dst += k * dst_strides_[j];
src += k * src_strides_[j];
rem %= idx_strides_[j + 1];
rem %= idx_strides_[j];
}
std::memcpy(dst, src, unit_);
}
......
target("infiniutils-test")
set_kind("binary")
add_deps("infini-utils")
on_install(function (target) end)
set_warnings("all", "error")
set_languages("cxx17")
add_files(os.projectdir().."/src/utils-test/*.cc")
target_end()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment