Unverified Commit 12046d02 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #81 from YdrMaster/issue/78

issue/78/feat: 分离 rearrange 的规划和执行以供算子复用 
parents e9b1a513 ba9b0d83
......@@ -6,15 +6,15 @@
namespace utils {
void rearrange(
void *dst_,
const void *src_,
// Takes ownership of an already-built plan vector and keeps it as the
// object's storage.
RearrangeMeta::RearrangeMeta(std::vector<ptrdiff_t> plan)
    : _meta(std::move(plan)) {
}
std::optional<RearrangeMeta> RearrangeMeta::create(
const size_t *shape,
const ptrdiff_t *dst_strides_,
const ptrdiff_t *src_strides_,
size_t ndim,
size_t element_size) {
struct Dim {
size_t len;
ptrdiff_t dst, src;
......@@ -63,35 +63,69 @@ void rearrange(
}
dims.resize(ndim);
// 填写序号步长、输入步长和输出步长
std::vector<ptrdiff_t>
idx_strides(ndim + 1),
dst_strides(ndim),
src_strides(ndim);
idx_strides[ndim] = 1;
std::vector<ptrdiff_t> meta(2 + ndim * 3);
meta[0] = unit;
meta[1 + ndim] = 1;
for (size_t i = 0; i < ndim; ++i) {
idx_strides[i] = dims[i].len;
dst_strides[i] = dims[i].dst;
src_strides[i] = dims[i].src;
meta[1 + i] = dims[i].len;
meta[1 + 1 + ndim + i] = dims[i].dst;
meta[1 + 1 + ndim * 2 + i] = dims[i].src;
}
for (size_t i = ndim; i > 0; --i) {
idx_strides[i - 1] *= idx_strides[i];
for (ptrdiff_t i = ndim; i > 0; --i) {
meta[1 + i - 1] *= meta[1 + i];
}
return {RearrangeMeta(std::move(meta))};
}
// Number of dimensions held by the plan, recovered from the storage size:
// two leading scalar slots followed by three per-dimension arrays.
size_t RearrangeMeta::ndim() const {
    auto const per_dim_slots = _meta.size() - 2;
    return per_dim_slots / 3;
}
// Value stored in slot 0 of the plan; launch() passes it to memcpy as the
// number of bytes moved per copy.
size_t RearrangeMeta::unit() const {
    auto const bytes = _meta[0];
    return static_cast<size_t>(bytes);
}
// Value stored in slot 1 of the plan: the product of all dimension lengths,
// i.e. how many units one launch copies.
size_t RearrangeMeta::count() const {
    auto const total = _meta[1];
    return static_cast<size_t>(total);
}
// Pointer to the ndim() index strides, which start right after the two
// scalar slots (unit and count).
const ptrdiff_t *RearrangeMeta::idx_strides() const {
    constexpr size_t header_slots = 2;
    return _meta.data() + header_slots;
}
// Pointer to the ndim() destination strides, stored immediately after the
// index strides.
const ptrdiff_t *RearrangeMeta::dst_strides() const {
    auto const index_block = idx_strides();
    return index_block + ndim();
}
// Pointer to the ndim() source strides, stored immediately after the
// destination strides.
const ptrdiff_t *RearrangeMeta::src_strides() const {
    auto const dst_block = dst_strides();
    return dst_block + ndim();
}
// Executes the planned rearrange: copies count() units of unit() bytes each
// from the source buffer to the destination buffer.
//
// dst_: base address of the destination buffer.
// src_: base address of the source buffer.
//
// Every linear unit index in [0, count()) is decomposed, dimension by
// dimension, using idx_strides(); the resulting per-dimension coordinates are
// multiplied by dst_strides() / src_strides() and added to the base pointers
// as byte offsets.
void RearrangeMeta::launch(void *dst_, const void *src_) const {
    auto const ndim_ = ndim();
    // Signed so that all index/stride arithmetic below stays in ptrdiff_t and
    // negative strides never go through unsigned wraparound.
    auto const count_ = static_cast<ptrdiff_t>(count());
    auto const unit_ = unit();
    auto const idx_strides_ = idx_strides();
    auto const dst_strides_ = dst_strides();
    auto const src_strides_ = src_strides();

    // Perform the rearrange.
    if (count_ == 1) {
        // A single unit: one plain copy from the base addresses.
        std::memcpy(dst_, src_, unit_);
        return;
    }

    // The total number of units is count(); idx_strides() already skips that
    // slot, so idx_strides_[j] (not [j + 1]) is the divisor for dimension j.
    for (ptrdiff_t i = 0; i < count_; ++i) {
        auto dst = reinterpret_cast<char *>(dst_);
        auto src = reinterpret_cast<const char *>(src_);
        ptrdiff_t rem = i;
        for (size_t j = 0; j < ndim_; ++j) {
            auto const k = rem / idx_strides_[j];
            dst += k * dst_strides_[j];
            src += k * src_strides_[j];
            rem %= idx_strides_[j];
        }
        std::memcpy(dst, src, unit_);
    }
}
// One-shot rearrange: builds a plan from the given shape and strides with
// RearrangeMeta::create and runs it immediately. Aborts the process when no
// plan is returned.
void rearrange(
    void *dst,
    const void *src,
    const size_t *shape,
    const ptrdiff_t *dst_strides,
    const ptrdiff_t *src_strides,
    size_t ndim,
    size_t element_size) {
    auto const plan = RearrangeMeta::create(shape, dst_strides, src_strides, ndim, element_size);
    if (!plan) {
        std::abort();
    }
    plan->launch(dst, src);
}
} // namespace utils
#ifndef __INFINIUTILS_REARRANGE_H__
#define __INFINIUTILS_REARRANGE_H__
#include <optional>
#include <stddef.h>
#include <vector>
namespace utils {
// A precomputed plan for a strided copy ("rearrange"). Planning (create) is
// separated from execution (launch) so that one plan can be built once and
// reused.
class RearrangeMeta {
    // Flat plan storage, 2 + 3 * ndim() entries:
    //   [0]                             unit()
    //   [1]                             count()
    //   [2, 2 + ndim)                   idx_strides()
    //   [2 + ndim, 2 + 2 * ndim)        dst_strides()
    //   [2 + 2 * ndim, 2 + 3 * ndim)    src_strides()
    std::vector<ptrdiff_t> _meta;

    // Wraps an already-built plan vector. Private: plans are only produced by
    // create(). Explicit so a raw vector never converts to a plan implicitly.
    explicit RearrangeMeta(std::vector<ptrdiff_t>);

public:
    // Builds a plan from a shape, the destination and source strides, the
    // number of dimensions and the element size. Returns an empty optional
    // when no plan can be built.
    static std::optional<RearrangeMeta> create(
        const size_t *shape,
        const ptrdiff_t *dst_strides,
        const ptrdiff_t *src_strides,
        size_t ndim,
        size_t element_size);

    // Number of dimensions stored in the plan.
    size_t ndim() const;
    // Number of bytes moved by each individual copy.
    size_t unit() const;
    // Total number of units copied by one launch.
    size_t count() const;

    // Arrays of ndim() entries each, pointing into the plan storage; they stay
    // valid only as long as this object is alive.
    const ptrdiff_t *idx_strides() const;
    const ptrdiff_t *dst_strides() const;
    const ptrdiff_t *src_strides() const;

    // Executes the plan, copying from `src` to `dst`.
    void launch(void *dst, const void *src) const;
};
void rearrange(
void *dst,
const void *src,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment