rearrange.cc 3.93 KB
Newer Older
PanZezhong's avatar
PanZezhong committed
1
#include "rearrange.h"
YdrMaster's avatar
YdrMaster committed
2
3
4
5
6
7
8
#include "check.h"
#include <algorithm>
#include <cstring>
#include <vector>

namespace utils {

9
10
11
12
/// Wraps an already-populated metadata buffer.
///
/// The vector is taken by value and moved into `_meta`, so the caller's
/// buffer is consumed without an extra copy. The expected layout is the one
/// produced by `RearrangeMeta::create`.
RearrangeMeta::RearrangeMeta(std::vector<ptrdiff_t> meta)
    : _meta(std::move(meta)) {
}

std::optional<RearrangeMeta> RearrangeMeta::create(
YdrMaster's avatar
YdrMaster committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
    const size_t *shape,
    const ptrdiff_t *dst_strides_,
    const ptrdiff_t *src_strides_,
    size_t ndim,
    size_t element_size) {
    struct Dim {
        size_t len;
        ptrdiff_t dst, src;
    };

    std::vector<Dim> dims;
    for (size_t i = 0; i < ndim; ++i) {
        // 剔除初始的 1 长维度
        if (shape[i] != 1) {
            auto sd = dst_strides_[i], ss = src_strides_[i];
            // assert (sd != 0)
            dims.push_back(Dim{shape[i], sd, ss});
        }
    }
    // 排序
    std::sort(dims.begin(), dims.end(), [](const Dim &a, const Dim &b) {
        if (std::abs(a.dst) == std::abs(b.dst)) {
            if (std::abs(a.src) == std::abs(b.src)) {
                return a.len < b.len;
            }
            return std::abs(a.src) > std::abs(b.src);
        }
        return std::abs(a.dst) > std::abs(b.dst);
    });
    // # 合并连续维度
    ptrdiff_t unit = element_size;
    // ## 合并末尾连续维度到 unit
    for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
        if (it->dst == unit && it->src == unit) {
            unit *= it->len;
            ndim -= 1;
        } else {
            break;
        }
    }
    // ## 合并任意连续维度
    for (size_t i = ndim - 1; i > 0; --i) {
        auto &f = dims[i - 1];
        auto &b = dims[i];
        ptrdiff_t len = b.len;
        if (b.dst * len == f.dst && b.src * len == f.src) {
            f = Dim{b.len * f.len, b.dst, b.src};
            b = Dim{1, 0, 0};
            ndim -= 1;
        }
    }
    dims.resize(ndim);
    // 填写序号步长、输入步长和输出步长
66
67
68
    std::vector<ptrdiff_t> meta(2 + ndim * 3);
    meta[0] = unit;
    meta[1 + ndim] = 1;
YdrMaster's avatar
YdrMaster committed
69
    for (size_t i = 0; i < ndim; ++i) {
70
71
72
        meta[1 + i] = dims[i].len;
        meta[1 + 1 + ndim + i] = dims[i].dst;
        meta[1 + 1 + ndim * 2 + i] = dims[i].src;
YdrMaster's avatar
YdrMaster committed
73
    }
74
75
    for (ptrdiff_t i = ndim; i > 0; --i) {
        meta[1 + i - 1] *= meta[1 + i];
YdrMaster's avatar
YdrMaster committed
76
    }
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
    return {RearrangeMeta(std::move(meta))};
}

/// Number of dimensions stored in the plan.
///
/// The buffer holds two scalar entries followed by three per-dimension
/// tables, so the rank is recovered from the buffer size.
size_t RearrangeMeta::ndim() const {
    const size_t table_entries = _meta.size() - 2;
    return table_entries / 3;
}
/// First metadata entry, returned as `size_t`; `launch` passes it to
/// `std::memcpy` as the byte count of each copy.
size_t RearrangeMeta::unit() const {
    return static_cast<size_t>(_meta[0]);
}
/// Second metadata entry, returned as `size_t`; `create` stores the product
/// of all remaining dimension lengths there.
size_t RearrangeMeta::count() const {
    return static_cast<size_t>(_meta[1]);
}

/// Pointer to the index-stride table, which starts right after the two
/// leading scalar entries of the metadata buffer.
const ptrdiff_t *RearrangeMeta::idx_strides() const {
    const ptrdiff_t *const base = _meta.data();
    return base + 2;
}
/// Pointer to the destination-stride table, located `ndim()` entries after
/// the start of the index-stride table.
const ptrdiff_t *RearrangeMeta::dst_strides() const {
    const size_t rank = ndim();
    return idx_strides() + rank;
}
/// Pointer to the source-stride table, located `ndim()` entries after the
/// start of the destination-stride table.
const ptrdiff_t *RearrangeMeta::src_strides() const {
    const size_t rank = ndim();
    return dst_strides() + rank;
}

/// Executes the plan: copies `count()` units of `unit()` bytes each from
/// `src_` to `dst_`, placing every unit according to the stored strides.
///
/// @param dst_  Base address of the destination buffer.
/// @param src_  Base address of the source buffer.
void RearrangeMeta::launch(void *dst_, const void *src_) const {
    const size_t ndim_ = ndim();
    const size_t count_ = count();
    const size_t unit_ = unit();
    const ptrdiff_t *const idx_strides_ = idx_strides();
    const ptrdiff_t *const dst_strides_ = dst_strides();
    const ptrdiff_t *const src_strides_ = src_strides();

    // A single unit means no per-dimension addressing is needed.
    if (count_ == 1) {
        std::memcpy(dst_, src_, unit_);
        return;
    }

    char *const dst_base = static_cast<char *>(dst_);
    const char *const src_base = static_cast<const char *>(src_);

    // Visit every unit. idx_strides_[j] is the number of units spanned by one
    // step along dimension j (the last entry is 1), so dividing the linear
    // index by it yields the coordinate along that dimension. The table has
    // exactly ndim_ entries; indexing past it would read destination strides.
    for (size_t i = 0; i < count_; ++i) {
        // Signed arithmetic throughout so negative strides yield negative
        // offsets instead of relying on unsigned wraparound.
        ptrdiff_t rem = static_cast<ptrdiff_t>(i);
        ptrdiff_t dst_offset = 0;
        ptrdiff_t src_offset = 0;
        for (size_t j = 0; j < ndim_; ++j) {
            const ptrdiff_t k = rem / idx_strides_[j];
            dst_offset += k * dst_strides_[j];
            src_offset += k * src_strides_[j];
            rem %= idx_strides_[j];
        }
        std::memcpy(dst_base + dst_offset, src_base + src_offset, unit_);
    }
}

YdrMaster's avatar
YdrMaster committed
114
void rearrange(
115
116
    void *dst,
    const void *src,
YdrMaster's avatar
YdrMaster committed
117
    const size_t *shape,
118
119
    const ptrdiff_t *dst_strides,
    const ptrdiff_t *src_strides,
YdrMaster's avatar
YdrMaster committed
120
121
122
    size_t ndim,
    size_t element_size) {

123
124
125
    auto scheme = RearrangeMeta::create(shape, dst_strides, src_strides, ndim, element_size);
    if (scheme) {
        scheme->launch(dst, src);
YdrMaster's avatar
YdrMaster committed
126
    } else {
127
        std::abort();
YdrMaster's avatar
YdrMaster committed
128
129
130
131
    }
}

} // namespace utils