Commit 3a0d6510 authored by YdrMaster's avatar YdrMaster
Browse files

issue/152/fix: 改正 rearrange 元信息填充


Signed-off-by: default avatarYdrMaster <ydrml@hotmail.com>
parent 203de1a4
...@@ -66,5 +66,7 @@ int test_transpose_any(size_t index, std::vector<size_t> shape, std::vector<ptrd ...@@ -66,5 +66,7 @@ int test_transpose_any(size_t index, std::vector<size_t> shape, std::vector<ptrd
int test_rearrange() { int test_rearrange() {
return test_transpose_any(1, {3, 5}, {5, 1}, {1, 3}) return test_transpose_any(1, {3, 5}, {5, 1}, {1, 3})
+ test_transpose_any(2, {1, 2048}, {2048, 1}, {2048, 1}); + test_transpose_any(2, {1, 2048}, {2048, 1}, {2048, 1})
+ test_transpose_any(3, {2, 2, 2, 4}, {16, 8, 1, 2}, {16, 8, 4, 1})
+ test_transpose_any(4, {2, 2, 2, 2, 4}, {32, 16, 8, 1, 2}, {32, 16, 8, 4, 1});
} }
...@@ -70,15 +70,18 @@ Result<RearrangeMeta> RearrangeMeta::create( ...@@ -70,15 +70,18 @@ Result<RearrangeMeta> RearrangeMeta::create(
ndim -= 1; ndim -= 1;
} }
} }
dims.resize(ndim);
// 填写序号步长、输入步长和输出步长 // 填写序号步长、输入步长和输出步长
std::vector<ptrdiff_t> meta(2 + ndim * 3); std::vector<ptrdiff_t> meta(2 + ndim * 3);
meta[0] = unit; meta[0] = unit;
meta[1 + ndim] = 1; meta[1 + ndim] = 1;
for (size_t i = 0; i < ndim; ++i) { for (size_t i = 0, j = 0; i < ndim; ++i, ++j) {
meta[1 + i] = dims[i].len; // filter dim.len != 1
meta[1 + 1 + ndim + i] = dims[i].dst; while (dims[j].len == 1) {
meta[1 + 1 + ndim * 2 + i] = dims[i].src; ++j;
}
meta[1 + i] = dims[j].len;
meta[1 + i + 1 + ndim] = dims[j].dst;
meta[1 + i + 1 + ndim * 2] = dims[j].src;
} }
for (ptrdiff_t i = ndim; i > 0; --i) { for (ptrdiff_t i = ndim; i > 0; --i) {
meta[1 + i - 1] *= meta[1 + i]; meta[1 + i - 1] *= meta[1 + i];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment