rearrange.cc 6.04 KB
Newer Older
PanZezhong's avatar
PanZezhong committed
1
#include "rearrange.h"
YdrMaster's avatar
YdrMaster committed
2
3
4
5
6
#include "check.h"
#include <algorithm>
#include <cstring>
#include <vector>

7
8
9
10
#ifdef ENABLE_OMP
#include <omp.h>
#endif

YdrMaster's avatar
YdrMaster committed
11
12
namespace utils {

13
14
15
/// Wraps an already-encoded plan: the vector is taken by value and moved
/// into `_meta` without any validation.
RearrangeMeta::RearrangeMeta(std::vector<ptrdiff_t> meta)
    : _meta(std::move(meta)) {
}

16
/// Builds a compact copy plan for a strided rearrange.
///
/// The strides are multiplied by `element_size`, so the plan works in units of
/// `element_size` (bytes, given how `launch` applies them to `char` pointers).
/// Dimensions of length 1 are dropped, the rest are ordered by descending
/// absolute destination stride, and contiguous dimensions are merged so that
/// the plan performs as few, as large copies as possible.
///
/// Encoded layout of the resulting vector (n = number of remaining dimensions):
///   [0]                 size of one contiguous copy (`unit`)
///   [1 .. 1+n]          index strides: suffix products of the dimension
///                       lengths; [1] is the total number of copies, [1+n] is 1
///   [2+n .. 2+2n)       destination strides
///   [2+2n .. 2+3n)      source strides
///
/// @param shape         lengths of the `ndim` dimensions
/// @param dst_strides_  destination strides, in elements
/// @param src_strides_  source strides, in elements
/// @param ndim          number of dimensions
/// @param element_size  size of one element
/// @return the plan, or INFINI_STATUS_BAD_TENSOR_STRIDES when a dimension of
///         length != 1 has a zero destination stride.
Result<RearrangeMeta> RearrangeMeta::create(
    const size_t *shape,
    const ptrdiff_t *dst_strides_,
    const ptrdiff_t *src_strides_,
    size_t ndim,
    size_t element_size) {

    // Size of one contiguous copy; grows while trailing dimensions are folded in.
    ptrdiff_t unit = element_size;

    struct Dim {
        size_t len;
        ptrdiff_t dst, src;
    };

    std::vector<Dim> dims;
    for (size_t i = 0; i < ndim; ++i) {
        // Skip dimensions of length 1: they never advance either pointer.
        if (shape[i] != 1) {
            auto sd = dst_strides_[i] * unit, ss = src_strides_[i] * unit;
            // A zero destination stride is rejected as an invalid layout.
            if (sd == 0) {
                return INFINI_STATUS_BAD_TENSOR_STRIDES;
            }
            dims.push_back(Dim{shape[i], sd, ss});
        }
    }

    // Order by |dst stride| descending, then |src stride| descending, then length ascending.
    std::sort(dims.begin(), dims.end(), [](const Dim &a, const Dim &b) {
        if (std::abs(a.dst) == std::abs(b.dst)) {
            if (std::abs(a.src) == std::abs(b.src)) {
                return a.len < b.len;
            }
            return std::abs(a.src) > std::abs(b.src);
        }
        return std::abs(a.dst) > std::abs(b.dst);
    });

    // # Merge contiguous dimensions
    // ## Fold trailing dimensions that are contiguous on both sides into `unit`
    while (!dims.empty() && dims.back().dst == unit && dims.back().src == unit) {
        unit *= dims.back().len;
        dims.pop_back();
    }

    // ## Merge any adjacent pair where the front dimension continues the back one
    // Walk from the back so erasing index i never disturbs the indices still to be
    // visited. The merged-away dimension is removed right here; leaving a
    // placeholder and truncating the tail later would drop a real dimension
    // whenever a merge happens in front of an unmerged one.
    for (ptrdiff_t i = static_cast<ptrdiff_t>(dims.size()) - 1; i > 0; --i) {
        auto &f = dims[i - 1];
        const auto &b = dims[i];
        const ptrdiff_t len = b.len;
        if (b.dst * len == f.dst && b.src * len == f.src) {
            f = Dim{b.len * f.len, b.dst, b.src};
            dims.erase(dims.begin() + i);
        }
    }
    ndim = dims.size();

    // Fill in the index strides, destination strides and source strides.
    std::vector<ptrdiff_t> meta(2 + ndim * 3);
    meta[0] = unit;
    meta[1 + ndim] = 1;
    for (size_t i = 0; i < ndim; ++i) {
        meta[1 + i] = dims[i].len;
        meta[1 + 1 + ndim + i] = dims[i].dst;
        meta[1 + 1 + ndim * 2 + i] = dims[i].src;
    }
    // Turn the lengths into suffix products so meta[1] becomes the total copy count.
    for (ptrdiff_t i = ndim; i > 0; --i) {
        meta[1 + i - 1] *= meta[1 + i];
    }
    return Result<RearrangeMeta>(meta);
}

/// Number of dimensions in the plan, derived from the length of `_meta`
/// (2 leading slots plus 3 slots per dimension).
size_t RearrangeMeta::ndim() const {
    const size_t payload = _meta.size() - 2;
    return payload / 3;
}
/// Size of one contiguous copy, read from the first slot of `_meta`.
size_t RearrangeMeta::unit() const {
    return static_cast<size_t>(_meta[0]);
}
/// Total number of unit-sized copies, read from the second slot of `_meta`.
size_t RearrangeMeta::count() const {
    return static_cast<size_t>(_meta[1]);
}

/// Pointer to the index strides, which start at slot 2 of `_meta`
/// (the slot after the total count).
const ptrdiff_t *RearrangeMeta::idx_strides() const {
    const ptrdiff_t *const base = _meta.data();
    return base + 2;
}
/// Pointer to the destination strides, located `ndim()` slots after the index strides.
const ptrdiff_t *RearrangeMeta::dst_strides() const {
    const size_t rank = ndim();
    return idx_strides() + rank;
}
/// Pointer to the source strides, located `ndim()` slots after the destination strides.
const ptrdiff_t *RearrangeMeta::src_strides() const {
    const size_t rank = ndim();
    return dst_strides() + rank;
}

void RearrangeMeta::launch(void *dst_, const void *src_) const {
    auto const ndim_ = ndim();
    auto const count_ = count();
    auto const unit_ = unit();
    auto const idx_strides_ = idx_strides();
    auto const dst_strides_ = dst_strides();
    auto const src_strides_ = src_strides();
YdrMaster's avatar
YdrMaster committed
104
    // 执行 rearrange
105
106
    if (count_ == 1) {
        std::memcpy(dst_, src_, unit_);
YdrMaster's avatar
YdrMaster committed
107
    } else {
108
109
#pragma omp parallel for
        for (ptrdiff_t i = 0; i < (ptrdiff_t)count_; ++i) {
YdrMaster's avatar
YdrMaster committed
110
111
            auto dst = reinterpret_cast<char *>(dst_);
            auto src = reinterpret_cast<const char *>(src_);
112
113
            auto rem = i;
            for (size_t j = 0; j < ndim_; ++j) {
114
                auto k = rem / idx_strides_[j];
115
116
                dst += k * dst_strides_[j];
                src += k * src_strides_[j];
117
                rem %= idx_strides_[j];
YdrMaster's avatar
YdrMaster committed
118
            }
119
            std::memcpy(dst, src, unit_);
YdrMaster's avatar
YdrMaster committed
120
121
122
123
        }
    }
}

YdrMaster's avatar
YdrMaster committed
124
void rearrange(
125
126
    void *dst,
    const void *src,
YdrMaster's avatar
YdrMaster committed
127
    const size_t *shape,
128
129
    const ptrdiff_t *dst_strides,
    const ptrdiff_t *src_strides,
YdrMaster's avatar
YdrMaster committed
130
131
132
    size_t ndim,
    size_t element_size) {

133
134
135
    auto scheme = RearrangeMeta::create(shape, dst_strides, src_strides, ndim, element_size);
    if (scheme) {
        scheme->launch(dst, src);
YdrMaster's avatar
YdrMaster committed
136
    } else {
137
        std::abort();
YdrMaster's avatar
YdrMaster committed
138
139
140
    }
}

pwhMass's avatar
pwhMass committed
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
/// Splits the copy unit into smaller pieces of a preferred size.
///
/// The first entry of `candidates` that evenly divides the current unit is
/// chosen. If it differs from the current unit, the returned plan has one
/// extra innermost dimension of length `unit / candidate`, with the candidate
/// as its stride on both the destination and the source side.
///
/// @param candidates unit sizes in order of preference; entries equal to zero
///                   are ignored, since they cannot divide anything
/// @return a copy of this plan when the chosen candidate equals the current
///         unit, the re-split plan otherwise, or INFINI_STATUS_BAD_PARAM when
///         no candidate divides the current unit.
utils::Result<RearrangeMeta> RearrangeMeta::distributeUnit(const std::vector<size_t> &candidates) const {
    // Current size of one contiguous copy.
    size_t current_unit = _meta[0];

    // Pick the first candidate that divides the current unit.
    size_t new_unit = 0;
    for (size_t candidate : candidates) {
        // Skip zero up front: `current_unit % 0` is undefined behaviour.
        if (candidate != 0 && current_unit % candidate == 0) {
            new_unit = candidate;
            break;
        }
    }

    // No usable candidate.
    if (new_unit == 0) {
        return INFINI_STATUS_BAD_PARAM;
    }

    // Nothing to split: return a copy of this plan.
    if (new_unit == current_unit) {
        return Result<RearrangeMeta>(_meta);
    }

    // Dimension count of the current plan; the new plan has one more.
    size_t ndim_value = this->ndim();

    // New encoded layout: 2 leading slots plus 3 slots per dimension.
    std::vector<ptrdiff_t> layout(2 + (ndim_value + 1) * 3, 0);

    layout[0] = new_unit;

    // Length of the added innermost dimension.
    ptrdiff_t extra = current_unit / new_unit;

    // Position of the first index-stride slot (the total count) in `_meta`.
    ptrdiff_t idx_offset = 1;

    // Start of each section in the new layout.
    ptrdiff_t *new_idx = layout.data() + 1;
    ptrdiff_t *new_dst = layout.data() + 2 + (ndim_value + 1);
    ptrdiff_t *new_src = layout.data() + 2 + (ndim_value + 1) * 2;

    // Scale every existing index stride (total count included) by the length
    // of the added dimension.
    for (size_t i = 0; i < ndim_value + 1; ++i) {
        new_idx[i] = _meta[idx_offset + i] * extra;
    }

    // The added innermost dimension has index stride 1.
    new_idx[ndim_value + 1] = 1;

    // Keep the existing destination strides; the added dimension steps by the new unit.
    for (size_t i = 0; i < ndim_value; ++i) {
        new_dst[i] = dst_strides()[i];
    }
    new_dst[ndim_value] = new_unit;

    // Keep the existing source strides; the added dimension steps by the new unit.
    for (size_t i = 0; i < ndim_value; ++i) {
        new_src[i] = src_strides()[i];
    }
    new_src[ndim_value] = new_unit;

    return Result<RearrangeMeta>(layout);
}

YdrMaster's avatar
YdrMaster committed
210
} // namespace utils