rearrange.cc 4.02 KB
Newer Older
PanZezhong's avatar
PanZezhong committed
1
#include "rearrange.h"
YdrMaster's avatar
YdrMaster committed
2
3
4
5
6
#include "check.h"
#include <algorithm>
#include <cstring>
#include <vector>

7
8
9
10
#ifdef ENABLE_OMP
#include <omp.h>
#endif

YdrMaster's avatar
YdrMaster committed
11
12
namespace utils {

13
14
15
16
RearrangeMeta::RearrangeMeta(std::vector<ptrdiff_t> meta)
    : _meta(std::move(meta)) {}

std::optional<RearrangeMeta> RearrangeMeta::create(
YdrMaster's avatar
YdrMaster committed
17
18
19
20
21
    const size_t *shape,
    const ptrdiff_t *dst_strides_,
    const ptrdiff_t *src_strides_,
    size_t ndim,
    size_t element_size) {
PanZezhong's avatar
PanZezhong committed
22
23
24

    ptrdiff_t unit = element_size;

YdrMaster's avatar
YdrMaster committed
25
26
27
28
29
30
31
32
33
    struct Dim {
        size_t len;
        ptrdiff_t dst, src;
    };

    std::vector<Dim> dims;
    for (size_t i = 0; i < ndim; ++i) {
        // 剔除初始的 1 长维度
        if (shape[i] != 1) {
PanZezhong's avatar
PanZezhong committed
34
            auto sd = dst_strides_[i] * unit, ss = src_strides_[i] * unit;
YdrMaster's avatar
YdrMaster committed
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
            // assert (sd != 0)
            dims.push_back(Dim{shape[i], sd, ss});
        }
    }
    // 排序
    std::sort(dims.begin(), dims.end(), [](const Dim &a, const Dim &b) {
        if (std::abs(a.dst) == std::abs(b.dst)) {
            if (std::abs(a.src) == std::abs(b.src)) {
                return a.len < b.len;
            }
            return std::abs(a.src) > std::abs(b.src);
        }
        return std::abs(a.dst) > std::abs(b.dst);
    });
    // # 合并连续维度
    // ## 合并末尾连续维度到 unit
    for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
        if (it->dst == unit && it->src == unit) {
            unit *= it->len;
            ndim -= 1;
        } else {
            break;
        }
    }
    // ## 合并任意连续维度
PanZezhong's avatar
PanZezhong committed
60
    for (ptrdiff_t i = ndim - 1; i > 0; --i) {
YdrMaster's avatar
YdrMaster committed
61
62
63
64
65
66
67
68
69
70
71
        auto &f = dims[i - 1];
        auto &b = dims[i];
        ptrdiff_t len = b.len;
        if (b.dst * len == f.dst && b.src * len == f.src) {
            f = Dim{b.len * f.len, b.dst, b.src};
            b = Dim{1, 0, 0};
            ndim -= 1;
        }
    }
    dims.resize(ndim);
    // 填写序号步长、输入步长和输出步长
72
73
74
    std::vector<ptrdiff_t> meta(2 + ndim * 3);
    meta[0] = unit;
    meta[1 + ndim] = 1;
YdrMaster's avatar
YdrMaster committed
75
    for (size_t i = 0; i < ndim; ++i) {
76
77
78
        meta[1 + i] = dims[i].len;
        meta[1 + 1 + ndim + i] = dims[i].dst;
        meta[1 + 1 + ndim * 2 + i] = dims[i].src;
YdrMaster's avatar
YdrMaster committed
79
    }
80
81
    for (ptrdiff_t i = ndim; i > 0; --i) {
        meta[1 + i - 1] *= meta[1 + i];
YdrMaster's avatar
YdrMaster committed
82
    }
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
    return {RearrangeMeta(std::move(meta))};
}

/// Number of dimensions in the plan, recovered from the metadata length
/// (the vector holds 2 + 3 * ndim entries).
size_t RearrangeMeta::ndim() const {
    const size_t payload = _meta.size() - 2;
    return payload / 3;
}

/// Number of bytes copied by each memcpy (first metadata entry).
size_t RearrangeMeta::unit() const {
    return static_cast<size_t>(_meta[0]);
}

/// Total number of memcpy operations (second metadata entry).
size_t RearrangeMeta::count() const {
    return static_cast<size_t>(_meta[1]);
}

/// Pointer to the `ndim()` index strides, stored from metadata slot 2 onwards.
const ptrdiff_t *RearrangeMeta::idx_strides() const {
    return _meta.data() + 2;
}

/// Pointer to the `ndim()` destination strides, stored right after the index strides.
const ptrdiff_t *RearrangeMeta::dst_strides() const {
    return _meta.data() + 2 + ndim();
}

/// Pointer to the `ndim()` source strides, stored right after the destination strides.
const ptrdiff_t *RearrangeMeta::src_strides() const {
    return _meta.data() + 2 + 2 * ndim();
}

void RearrangeMeta::launch(void *dst_, const void *src_) const {
    auto const ndim_ = ndim();
    auto const count_ = count();
    auto const unit_ = unit();
    auto const idx_strides_ = idx_strides();
    auto const dst_strides_ = dst_strides();
    auto const src_strides_ = src_strides();
YdrMaster's avatar
YdrMaster committed
101
    // 执行 rearrange
102
103
    if (count_ == 1) {
        std::memcpy(dst_, src_, unit_);
YdrMaster's avatar
YdrMaster committed
104
    } else {
105
106
#pragma omp parallel for
        for (ptrdiff_t i = 0; i < (ptrdiff_t)count_; ++i) {
YdrMaster's avatar
YdrMaster committed
107
108
            auto dst = reinterpret_cast<char *>(dst_);
            auto src = reinterpret_cast<const char *>(src_);
109
110
            auto rem = i;
            for (size_t j = 0; j < ndim_; ++j) {
111
                auto k = rem / idx_strides_[j];
112
113
                dst += k * dst_strides_[j];
                src += k * src_strides_[j];
114
                rem %= idx_strides_[j];
YdrMaster's avatar
YdrMaster committed
115
            }
116
            std::memcpy(dst, src, unit_);
YdrMaster's avatar
YdrMaster committed
117
118
119
120
        }
    }
}

YdrMaster's avatar
YdrMaster committed
121
void rearrange(
122
123
    void *dst,
    const void *src,
YdrMaster's avatar
YdrMaster committed
124
    const size_t *shape,
125
126
    const ptrdiff_t *dst_strides,
    const ptrdiff_t *src_strides,
YdrMaster's avatar
YdrMaster committed
127
128
129
    size_t ndim,
    size_t element_size) {

130
131
132
    auto scheme = RearrangeMeta::create(shape, dst_strides, src_strides, ndim, element_size);
    if (scheme) {
        scheme->launch(dst, src);
YdrMaster's avatar
YdrMaster committed
133
    } else {
134
        std::abort();
YdrMaster's avatar
YdrMaster committed
135
136
137
138
    }
}

} // namespace utils