Unverified Commit 12046d02 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #81 from YdrMaster/issue/78

issue/78/feat: 分离 rearrange 的规划和执行以供算子复用 
parents e9b1a513 ba9b0d83
...@@ -6,15 +6,15 @@ ...@@ -6,15 +6,15 @@
namespace utils { namespace utils {
void rearrange( RearrangeMeta::RearrangeMeta(std::vector<ptrdiff_t> meta)
void *dst_, : _meta(std::move(meta)) {}
const void *src_,
std::optional<RearrangeMeta> RearrangeMeta::create(
const size_t *shape, const size_t *shape,
const ptrdiff_t *dst_strides_, const ptrdiff_t *dst_strides_,
const ptrdiff_t *src_strides_, const ptrdiff_t *src_strides_,
size_t ndim, size_t ndim,
size_t element_size) { size_t element_size) {
struct Dim { struct Dim {
size_t len; size_t len;
ptrdiff_t dst, src; ptrdiff_t dst, src;
...@@ -63,35 +63,69 @@ void rearrange( ...@@ -63,35 +63,69 @@ void rearrange(
} }
dims.resize(ndim); dims.resize(ndim);
// 填写序号步长、输入步长和输出步长 // 填写序号步长、输入步长和输出步长
std::vector<ptrdiff_t> std::vector<ptrdiff_t> meta(2 + ndim * 3);
idx_strides(ndim + 1), meta[0] = unit;
dst_strides(ndim), meta[1 + ndim] = 1;
src_strides(ndim);
idx_strides[ndim] = 1;
for (size_t i = 0; i < ndim; ++i) { for (size_t i = 0; i < ndim; ++i) {
idx_strides[i] = dims[i].len; meta[1 + i] = dims[i].len;
dst_strides[i] = dims[i].dst; meta[1 + 1 + ndim + i] = dims[i].dst;
src_strides[i] = dims[i].src; meta[1 + 1 + ndim * 2 + i] = dims[i].src;
} }
for (size_t i = ndim; i > 0; --i) { for (ptrdiff_t i = ndim; i > 0; --i) {
idx_strides[i - 1] *= idx_strides[i]; meta[1 + i - 1] *= meta[1 + i];
} }
return {RearrangeMeta(std::move(meta))};
}
// Number of dimensions in the plan. The buffer holds 2 leading slots followed
// by three tables of ndim entries each, so ndim is recovered from its size.
size_t RearrangeMeta::ndim() const {
    auto const table_slots = _meta.size() - 2;
    return table_slots / 3;
}
// Value stored in slot 0 of the plan buffer; launch() passes it to memcpy as
// the number of bytes copied per unit.
size_t RearrangeMeta::unit() const {
    auto const bytes_per_unit = _meta[0];
    return bytes_per_unit;
}
// Value stored in slot 1 of the plan buffer. create() fills it with the
// product of all dimension lengths, i.e. the total number of units.
size_t RearrangeMeta::count() const {
    auto const total_units = _meta[1];
    return total_units;
}
// Pointer to the index-stride table, which starts right after the two leading
// slots (unit and total count). Entry j is the product of the lengths of the
// dimensions after j, so the last of the ndim() entries is 1.
const ptrdiff_t *RearrangeMeta::idx_strides() const {
    const ptrdiff_t *const base = _meta.data();
    return base + 2;
}
// Pointer to the destination-stride table (ndim() entries), stored directly
// after the index-stride table.
const ptrdiff_t *RearrangeMeta::dst_strides() const {
    auto const skipped = ndim();
    return idx_strides() + skipped;
}
// Pointer to the source-stride table (ndim() entries), stored directly after
// the destination-stride table.
const ptrdiff_t *RearrangeMeta::src_strides() const {
    auto const skipped = ndim();
    return dst_strides() + skipped;
}
// Executes the planned rearrange: copies count() units of unit() bytes each
// from src_ to dst_, locating every unit through the stored stride tables.
//
// dst_ / src_: base addresses of the destination and source buffers. The
// stored dst/src strides are applied to them as byte offsets (the pointers are
// advanced as char *).
void RearrangeMeta::launch(void *dst_, const void *src_) const {
    auto const ndim_ = ndim();
    auto const count_ = count();
    auto const unit_ = unit();
    auto const idx_strides_ = idx_strides();
    auto const dst_strides_ = dst_strides();
    auto const src_strides_ = src_strides();
    // Perform the rearrange
    if (count_ == 1) {
        // A single unit sits at offset 0 on both sides.
        std::memcpy(dst_, src_, unit_);
        return;
    }
    // Visit every unit. The bound is the total count stored in slot 1;
    // idx_strides_[0] only covers one slice of the first dimension, because
    // idx_strides() starts one slot after the total count.
    for (size_t i = 0; i < count_; ++i) {
        auto dst = reinterpret_cast<char *>(dst_);
        auto src = reinterpret_cast<const char *>(src_);
        // Keep the running index signed so that negative strides yield
        // negative offsets instead of going through unsigned arithmetic.
        auto rem = static_cast<ptrdiff_t>(i);
        for (size_t j = 0; j < ndim_; ++j) {
            // idx_strides_[j] is the number of units spanned by one step along
            // dimension j, so the quotient is the coordinate in dimension j.
            auto const k = rem / idx_strides_[j];
            dst += k * dst_strides_[j];
            src += k * src_strides_[j];
            rem %= idx_strides_[j];
        }
        std::memcpy(dst, src, unit_);
    }
}
// One-shot convenience wrapper: builds a RearrangeMeta plan from the given
// shape and strides and runs it immediately. If no plan can be created, the
// process is aborted.
void rearrange(
    void *dst,
    const void *src,
    const size_t *shape,
    const ptrdiff_t *dst_strides,
    const ptrdiff_t *src_strides,
    size_t ndim,
    size_t element_size) {
    auto const plan = RearrangeMeta::create(shape, dst_strides, src_strides, ndim, element_size);
    // Guard clause: planning failure is fatal here.
    if (!plan.has_value()) {
        std::abort();
    }
    plan->launch(dst, src);
}
} // namespace utils } // namespace utils
#ifndef __INFINIUTILS_REARRANGE_H__ #ifndef __INFINIUTILS_REARRANGE_H__
#define __INFINIUTILS_REARRANGE_H__ #define __INFINIUTILS_REARRANGE_H__
#include <optional>
#include <stddef.h> #include <stddef.h>
#include <vector>
namespace utils { namespace utils {
// Precomputed plan for a strided copy ("rearrange"). Planning (create) is
// separated from execution (launch) so that a plan can be built once and then
// run, or have its tables read by other code.
class RearrangeMeta {
    // Flat plan buffer of 2 + 3 * ndim entries:
    //   [0]                 bytes copied per unit
    //   [1]                 total number of units
    //   [2, 2 + ndim)       index strides (the last entry is 1)
    //   next ndim entries   destination strides
    //   last ndim entries   source strides
    std::vector<ptrdiff_t> _meta;

    // Takes ownership of an already laid-out plan buffer. Private so that
    // instances are only produced by create(); explicit so that a vector is
    // never silently converted into a plan.
    explicit RearrangeMeta(std::vector<ptrdiff_t>);

public:
    // Builds a plan for copying a tensor with the given shape between the two
    // stride layouts. shape, dst_strides and src_strides each point to ndim
    // entries; element_size is the size of one element.
    // Returns an empty optional when no plan is produced (the free function
    // rearrange() treats that as a fatal error).
    static std::optional<RearrangeMeta> create(
        const size_t *shape,
        const ptrdiff_t *dst_strides,
        const ptrdiff_t *src_strides,
        size_t ndim,
        size_t element_size);

    // Number of dimensions described by the plan.
    size_t ndim() const;
    // Number of bytes copied per unit.
    size_t unit() const;
    // Total number of units copied by launch().
    size_t count() const;

    // Tables of ndim() entries each, pointing into the plan buffer; they stay
    // valid only as long as this object is alive.
    const ptrdiff_t *idx_strides() const;
    const ptrdiff_t *dst_strides() const;
    const ptrdiff_t *src_strides() const;

    // Runs the plan, copying from src to dst.
    void launch(void *dst, const void *src) const;
};
void rearrange( void rearrange(
void *dst, void *dst,
const void *src, const void *src,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment