swizzle.cc 3.47 KB
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
/*!
 * \file layout/swizzle.cc
 * \brief Define swizzled layout
 *
 */

#include "swizzle.h"

#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <cmath>

namespace tvm {
namespace tl {

/*!
 * \brief Construct a swizzle pattern from its three bit-field parameters.
 *
 * swizzle() evaluates `1 << base` and `((1 << bits) - 1) << shift` in plain
 * `int` arithmetic, so the parameters are validated here to keep both shifts
 * inside the non-negative range of `int`.
 *
 * \param bits  Width, in bits, of the mask applied by swizzle().
 * \param base  Number of low-order bits that swizzle() leaves unchanged.
 * \param shift Bit position of the mask, and the distance the masked bits are
 *              shifted right before being XOR-ed in. Must be >= bits.
 */
SwizzlePattern::SwizzlePattern(int bits, int base, int shift)
    : bits_(bits), base_(base), shift_(shift) {
  // Number of value (non-sign) bits in an `int`, assuming 8-bit bytes.
  constexpr int kIntValueBits = static_cast<int>(sizeof(int) * 8) - 1;

  ICHECK(bits >= 0) << "SwizzlePattern: bits must be non-negative, got "
                    << bits;
  ICHECK(base >= 0) << "SwizzlePattern: base must be non-negative, got "
                    << base;
  ICHECK(shift >= 0) << "SwizzlePattern: shift must be non-negative, got "
                     << shift;
  ICHECK(shift >= bits) << "SwizzlePattern: shift (" << shift
                        << ") must be >= bits (" << bits << ")";
  // `1 << base` must stay a positive int.
  ICHECK(base < kIntValueBits)
      << "SwizzlePattern: base (" << base << ") must be < " << kIntValueBits;
  // The highest mask bit sits at position shift + bits - 1; written as a
  // subtraction so the check itself cannot overflow.
  ICHECK(shift <= kIntValueBits - bits)
      << "SwizzlePattern: bits + shift must be <= " << kIntValueBits
      << ", got bits=" << bits << ", shift=" << shift;
}

/*!
 * \brief Apply the swizzle pattern to an index expression.
 *
 * The expression is split at bit `base_` into a low part (expr mod 2^base_)
 * and a high part (expr floordiv 2^base_). The low part is passed through
 * unchanged. The bits of the high part selected by a `bits_`-wide mask placed
 * at bit `shift_` are shifted right by `shift_` and XOR-ed back into the high
 * part. The two parts are then recombined.
 *
 * \param expr The index expression to swizzle.
 * \return The swizzled index expression.
 */
PrimExpr SwizzlePattern::swizzle(PrimExpr expr) const {
  const int granule = (1 << base_);
  const int select_mask = ((1 << bits_) - 1) << shift_;

  // Split the index at bit `base_`.
  PrimExpr upper = FloorDiv(expr, granule);
  const PrimExpr lower = FloorMod(expr, granule);

  // Fold the masked bits of the upper part down by `shift_` and XOR them in.
  const PrimExpr selected = bitwise_and(upper, select_mask);
  const PrimExpr folded = right_shift(selected, shift_);
  upper = bitwise_xor(upper, folded);

  // Re-attach the untouched low bits.
  return lower + upper * granule;
}

/*!
 * \brief Compare two swizzle patterns field by field.
 * \param other The pattern to compare against.
 * \return true when base, bits and shift all match.
 */
bool SwizzlePattern::operator==(const SwizzlePattern &other) const {
  if (!(base_ == other.base_)) {
    return false;
  }
  if (!(bits_ == other.bits_)) {
    return false;
  }
  return shift_ == other.shift_;
}

/*!
 * \brief Build a swizzled layout node.
 *
 * Stores the input shape and the swizzle pattern, then simplifies every
 * forward index expression with an analyzer prepared by UpdateAnalyzer().
 *
 * \param input_size    Extents of the layout's input dimensions.
 * \param forward_index Forward index expressions (simplified before storing).
 * \param pattern       Swizzle pattern applied by Forward().
 */
SwizzledLayoutNode::SwizzledLayoutNode(Array<PrimExpr> input_size,
                                       Array<PrimExpr> forward_index,
                                       SwizzlePattern pattern)
    : pattern_(pattern) {
  // input_size_ is assigned before UpdateAnalyzer() is called; keep this
  // order.
  input_size_ = input_size;

  arith::Analyzer simplifier;
  UpdateAnalyzer(&simplifier);

  auto simplify_one = [&](const PrimExpr &index) {
    return simplifier.Simplify(index);
  };
  forward_index_ = forward_index.Map(simplify_one);
}

/*!
 * \brief Evaluate the layout's forward mapping with the swizzle applied.
 *
 * Delegates to LayoutNode::Forward() and then replaces the last resulting
 * index with its swizzled form; all other indices are returned as-is.
 *
 * \param vars Input index expressions.
 * \return The forward indices, with the last one swizzled.
 */
Array<PrimExpr> SwizzledLayoutNode::Forward(const Array<PrimExpr> &vars) const {
  auto indices = LayoutNode::Forward(vars);

  // Only the trailing index goes through the swizzle pattern.
  auto innermost = indices.back();
  indices.pop_back();
  auto swizzled = pattern_.swizzle(innermost);
  indices.push_back(swizzled);

  return indices;
}

/*!
 * \brief Produce a human-readable description of this layout.
 * \return The base LayoutNode description followed by the swizzle parameters
 *         in the order base, bits, shift.
 */
std::string SwizzledLayoutNode::DebugOutput() const {
  std::ostringstream out;
  out << LayoutNode::DebugOutput() << "Layout Swizzle: " << pattern_.Base()
      << " " << pattern_.Bits() << " " << pattern_.Shift();
  return out.str();
}

/*!
 * \brief Inverse of a swizzled layout; not implemented.
 *
 * Always trips the internal check with the message "Not Implemented.".
 *
 * \return A default-constructed Layout, reached only if the check does not
 *         abort control flow.
 */
Layout SwizzledLayoutNode::Inverse() const {
  ICHECK(0) << "Not Implemented.";
  return {};
}

/*!
 * \brief Check whether two swizzled layouts are equal.
 * \param other      The layout node to compare against.
 * \param skip_index Forwarded unchanged to LayoutNode::IsEqual().
 * \return true when the base-class comparison succeeds and both nodes carry
 *         the same swizzle pattern.
 */
bool SwizzledLayoutNode::IsEqual(const SwizzledLayoutNode *other,
                                 bool skip_index) const {
  // The base-class comparison runs first; the pattern is only consulted when
  // it passes.
  if (!LayoutNode::IsEqual(other, skip_index)) {
    return false;
  }
  return pattern_ == other->pattern_;
}

/*!
 * \brief Construct a swizzled layout from iteration variables.
 *
 * Each iteration variable is mapped to the input placeholder of the same
 * position, its domain is required to start at zero, and its extent becomes
 * the matching input size. The forward index expressions are rewritten in
 * terms of the placeholders before the node is created.
 *
 * \param forward_var   Iteration variables that appear in forward_index.
 * \param forward_index Forward index expressions over forward_var.
 * \param pattern       Swizzle pattern stored in the node.
 */
SwizzledLayout::SwizzledLayout(Array<IterVar> forward_var,
                               Array<PrimExpr> forward_index,
                               SwizzlePattern pattern) {
  Map<Var, PrimExpr> placeholder_of;
  Array<PrimExpr> extents;
  for (size_t i = 0; i < forward_var.size(); ++i) {
    placeholder_of.Set(forward_var[i]->var, InputPlaceholder(i));
    // Every iteration domain must begin at zero.
    CHECK(is_zero(forward_var[i]->dom->min));
    extents.push_back(forward_var[i]->dom->extent);
  }

  // Rewrite the index expressions over placeholders instead of the variables.
  auto to_placeholders = [&](const PrimExpr &index) {
    return Substitute(index, placeholder_of);
  };
  forward_index = forward_index.Map(to_placeholders);

  data_ = make_object<SwizzledLayoutNode>(extents, forward_index, pattern);
}

/*!
 * \brief Construct a swizzled layout directly from an input shape.
 * \param input_size    Extents of the layout's input dimensions.
 * \param forward_index Forward index expressions passed to the node.
 * \param pattern       Swizzle pattern stored in the node.
 */
SwizzledLayout::SwizzledLayout(Array<PrimExpr> input_size,
                               Array<PrimExpr> forward_index,
                               SwizzlePattern pattern) {
  data_ = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern);
}

/*!
 * \brief Register SwizzledLayoutNode with the reflection system.
 *
 * Creates the ObjectDef entry for the node type; no fields or methods are
 * attached to it here.
 */
void SwizzledLayoutNode::RegisterReflection() {
  tvm::ffi::reflection::ObjectDef<SwizzledLayoutNode>();
}

/*!
 * \brief Structural-equality hook for swizzled layouts.
 *
 * Compares, in order, the input shapes, the forward index expressions and the
 * swizzle patterns, stopping at the first mismatch.
 *
 * \param other The node to compare against.
 * \param equal Reducer used to compare the nested expressions.
 * \return true when all three components match.
 */
bool SwizzledLayoutNode::SEqualReduce(const SwizzledLayoutNode *other,
                                      SEqualReducer equal) const {
  if (!equal(this->InputShape(), other->InputShape())) {
    return false;
  }
  if (!equal(this->forward_index_, other->forward_index_)) {
    return false;
  }
  return pattern_ == other->pattern_;
}

// Register SwizzledLayoutNode through TVM's node-type registration macro.
TVM_REGISTER_NODE_TYPE(SwizzledLayoutNode);

} // namespace tl
} // namespace tvm