/*!
 * \file layout/swizzle.cc
 * \brief Implementation of the swizzle pattern and the swizzled layout.
 *
 */

#include "swizzle.h"

#include <tvm/node/node.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <cmath>

namespace tvm {
namespace tl {

/*!
 * \brief Construct a swizzle pattern from its three bit-field parameters.
 *
 * swizzle() evaluates `1 << base` and `((1 << bits) - 1) << shift` in plain
 * `int` arithmetic. Besides the sign and ordering checks, the parameters are
 * therefore also required to keep those shifts inside the value bits of
 * `int`; otherwise the shift would overflow.
 *
 * \param bits  Width, in bits, of the field that is XOR-ed into the index.
 * \param base  Number of low index bits that swizzle() leaves untouched.
 * \param shift Position of the source field above the target field; must be
 *              at least `bits`.
 */
SwizzlePattern::SwizzlePattern(int bits, int base, int shift)
    : bits_(bits), base_(base), shift_(shift) {
  // Number of value (non-sign) bits available in an `int`.
  constexpr int kIntValueBits = static_cast<int>(sizeof(int)) * 8 - 1;

  ICHECK(bits >= 0) << "SwizzlePattern: bits must be non-negative, got "
                    << bits;
  ICHECK(base >= 0) << "SwizzlePattern: base must be non-negative, got "
                    << base;
  ICHECK(shift >= 0) << "SwizzlePattern: shift must be non-negative, got "
                     << shift;
  ICHECK(shift >= bits) << "SwizzlePattern: shift (" << shift
                        << ") must be >= bits (" << bits << ")";
  // `1 << base` must be representable as a positive int.
  ICHECK(base < kIntValueBits)
      << "SwizzlePattern: base (" << base << ") must be < " << kIntValueBits;
  // `((1 << bits) - 1) << shift` must be representable as a positive int.
  // Written as a subtraction so the comparison itself cannot overflow.
  ICHECK(shift <= kIntValueBits - bits)
      << "SwizzlePattern: bits + shift must be <= " << kIntValueBits
      << ", got bits=" << bits << ", shift=" << shift;
}

/*!
 * \brief Apply the XOR swizzle to an index expression.
 *
 * The index is split at bit `base_`: the low part is kept as is, while in the
 * high part the `bits_`-wide field starting at bit `shift_` is shifted down
 * by `shift_` and XOR-ed into the high part. Both parts are then recombined.
 *
 * \param expr The index expression to swizzle.
 * \return The swizzled index expression.
 */
PrimExpr SwizzlePattern::swizzle(PrimExpr expr) const {
  const int low_extent = (1 << base_);
  const int field_mask = ((1 << bits_) - 1) << shift_;

  // Part of the index below bit `base_`; passes through unchanged.
  const PrimExpr untouched = FloorMod(expr, low_extent);
  // Part of the index from bit `base_` upwards; this is what gets swizzled.
  const PrimExpr upper = FloorDiv(expr, low_extent);

  const PrimExpr source_field = bitwise_and(upper, field_mask);
  const PrimExpr aligned_field = right_shift(source_field, shift_);
  const PrimExpr swizzled_upper = bitwise_xor(upper, aligned_field);

  return untouched + swizzled_upper * low_extent;
}

35
36
37
bool SwizzlePattern::operator==(const SwizzlePattern &other) const {
  return std::tie(base_, bits_, shift_) ==
         std::tie(other.base_, other.bits_, other.shift_);
38
39
}

40
41
SwizzledLayoutNode::SwizzledLayoutNode(Array<PrimExpr> input_size,
                                       Array<PrimExpr> forward_index,
42
43
44
45
46
                                       SwizzlePattern pattern)
    : pattern_(pattern) {
  input_size_ = input_size;
  arith::Analyzer analyzer;
  UpdateAnalyzer(&analyzer);
47
48
  forward_index_ = forward_index.Map(
      [&](const PrimExpr &e) { return analyzer.Simplify(e); });
49
50
}

51
Array<PrimExpr> SwizzledLayoutNode::Forward(const Array<PrimExpr> &vars) const {
52
53
54
55
56
57
58
  auto expr_list = LayoutNode::Forward(vars);
  auto expr = expr_list.back();
  expr_list.pop_back();
  expr_list.push_back(pattern_.swizzle(expr));
  return expr_list;
}

59
60
61
62
63
64
std::string SwizzledLayoutNode::DebugOutput() const {
  std::stringstream ss;
  ss << LayoutNode::DebugOutput();
  ss << "Layout Swizzle: " << pattern_.Base() << " " << pattern_.Bits() << " "
     << pattern_.Shift();
  return ss.str();
65
66
67
68
69
70
71
}

/*!
 * \brief Inverse of a swizzled layout; not implemented.
 *
 * The ICHECK below has a constant-false condition, so calling this function
 * always reports "Not Implemented.".
 */
Layout SwizzledLayoutNode::Inverse() const {
  ICHECK(0) << "Not Implemented.";
  // Gives the function a return statement after the always-failing check.
  return {};
}

72
73
74
75
76
bool SwizzledLayoutNode::IsEqual(const SwizzledLayoutNode *other,
                                 bool skip_index) const {
  return LayoutNode::IsEqual(other, skip_index) && pattern_ == other->pattern_;
}

77
78
SwizzledLayout::SwizzledLayout(Array<IterVar> forward_var,
                               Array<PrimExpr> forward_index,
79
80
81
82
83
84
85
86
                               SwizzlePattern pattern) {
  Map<Var, PrimExpr> vmap;
  Array<PrimExpr> input_size;
  for (size_t i = 0; i < forward_var.size(); i++) {
    vmap.Set(forward_var[i]->var, InputPlaceholder(i));
    CHECK(is_zero(forward_var[i]->dom->min));
    input_size.push_back(forward_var[i]->dom->extent);
  }
87
88
  forward_index =
      forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
89

90
91
  auto n = tvm::ffi::make_object<SwizzledLayoutNode>(input_size, forward_index,
                                                     pattern);
92
93
94
  data_ = std::move(n);
}

95
96
SwizzledLayout::SwizzledLayout(Array<PrimExpr> input_size,
                               Array<PrimExpr> forward_index,
97
                               SwizzlePattern pattern) {
98
99
  auto n = tvm::ffi::make_object<SwizzledLayoutNode>(input_size, forward_index,
                                                     pattern);
100
101
102
  data_ = std::move(n);
}

103
104
105
void SwizzledLayoutNode::RegisterReflection() {
  namespace refl = tvm::ffi::reflection;
  refl::ObjectDef<SwizzledLayoutNode>();
106
}

} // namespace tl
} // namespace tvm