"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "1682b223c2ed35c703a8e27bc20396c12aa13eda"
/*!
 * \file layout/swizzle.cc
 * \brief Implementation of the swizzled layout (SwizzlePattern,
 *        SwizzledLayoutNode and SwizzledLayout).
 */

#include "swizzle.h"

#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <cmath>

namespace tvm {
namespace tl {

/*!
 * \brief Build a swizzle pattern from its three integer parameters.
 *
 * The parameters are consumed by swizzle(), which evaluates `1 << base` and
 * `((1 << bits) - 1) << shift` in plain `int` arithmetic. Besides the sign and
 * ordering checks, the constructor therefore rejects values for which those
 * shifts would not fit in a non-negative `int`.
 *
 * \param bits  Width of the bit field selected by the mask; must be >= 0.
 * \param base  Exponent of the low part that swizzle() leaves untouched
 *              (the index is split at `1 << base`); must be >= 0.
 * \param shift Position of the masked field and the distance it is shifted
 *              right before the XOR; must be >= bits.
 */
SwizzlePattern::SwizzlePattern(int bits, int base, int shift)
    : bits_(bits), base_(base), shift_(shift) {
  // Number of value (non-sign) bits in int; 31 on the usual 32-bit int.
  constexpr int kIntValueBits = static_cast<int>(sizeof(int) * 8) - 1;

  ICHECK(bits >= 0) << "SwizzlePattern: bits must be non-negative, got "
                    << bits;
  ICHECK(base >= 0) << "SwizzlePattern: base must be non-negative, got "
                    << base;
  ICHECK(shift >= 0) << "SwizzlePattern: shift must be non-negative, got "
                     << shift;
  ICHECK(shift >= bits) << "SwizzlePattern: shift (" << shift
                        << ") must be >= bits (" << bits << ")";
  // `1 << base` must stay representable in int.
  ICHECK(base < kIntValueBits)
      << "SwizzlePattern: base (" << base << ") is too large for int";
  // The highest mask bit sits at position bits + shift - 1. Written as a
  // subtraction so the comparison itself cannot overflow.
  ICHECK(shift <= kIntValueBits - bits)
      << "SwizzlePattern: bits (" << bits << ") + shift (" << shift
      << ") exceeds the " << kIntValueBits << " value bits of int";
}

/*!
 * \brief Apply the swizzle pattern to an index expression.
 *
 * The index is split at `1 << base_` into a low remainder and a high
 * quotient. The bits of the quotient selected by the mask
 * `((1 << bits_) - 1) << shift_` are shifted right by `shift_` and XOR-ed
 * back into the quotient; the low remainder passes through unchanged.
 *
 * \param expr The index expression to swizzle.
 * \return `low + swizzled_high * (1 << base_)`.
 */
PrimExpr SwizzlePattern::swizzle(PrimExpr expr) const {
  const int low_extent = (1 << base_);
  const int field_mask = ((1 << bits_) - 1) << shift_;

  const PrimExpr upper = FloorDiv(expr, low_extent);
  const PrimExpr lower = FloorMod(expr, low_extent);

  // Pick the masked field out of the upper part and fold it back in via XOR.
  const PrimExpr folded_field =
      right_shift(bitwise_and(upper, field_mask), shift_);
  const PrimExpr mixed_upper = bitwise_xor(upper, folded_field);

  return lower + mixed_upper * low_extent;
}

34
35
36
bool SwizzlePattern::operator==(const SwizzlePattern &other) const {
  return std::tie(base_, bits_, shift_) ==
         std::tie(other.base_, other.bits_, other.shift_);
37
38
}

39
40
SwizzledLayoutNode::SwizzledLayoutNode(Array<PrimExpr> input_size,
                                       Array<PrimExpr> forward_index,
41
42
43
44
45
                                       SwizzlePattern pattern)
    : pattern_(pattern) {
  input_size_ = input_size;
  arith::Analyzer analyzer;
  UpdateAnalyzer(&analyzer);
46
47
  forward_index_ = forward_index.Map(
      [&](const PrimExpr &e) { return analyzer.Simplify(e); });
48
49
}

50
Array<PrimExpr> SwizzledLayoutNode::Forward(const Array<PrimExpr> &vars) const {
51
52
53
54
55
56
57
  auto expr_list = LayoutNode::Forward(vars);
  auto expr = expr_list.back();
  expr_list.pop_back();
  expr_list.push_back(pattern_.swizzle(expr));
  return expr_list;
}

58
59
60
61
62
63
std::string SwizzledLayoutNode::DebugOutput() const {
  std::stringstream ss;
  ss << LayoutNode::DebugOutput();
  ss << "Layout Swizzle: " << pattern_.Base() << " " << pattern_.Bits() << " "
     << pattern_.Shift();
  return ss.str();
64
65
66
67
68
69
70
}

/*!
 * \brief Inverse of a swizzled layout; not implemented.
 *
 * The ICHECK(0) below always fails with "Not Implemented.", so no usable
 * inverse is ever produced.
 *
 * \return A default-constructed Layout, present only to satisfy the
 *         signature.
 */
Layout SwizzledLayoutNode::Inverse() const {
  ICHECK(0) << "Not Implemented.";
  return {};
}

71
72
73
74
75
bool SwizzledLayoutNode::IsEqual(const SwizzledLayoutNode *other,
                                 bool skip_index) const {
  return LayoutNode::IsEqual(other, skip_index) && pattern_ == other->pattern_;
}

76
77
SwizzledLayout::SwizzledLayout(Array<IterVar> forward_var,
                               Array<PrimExpr> forward_index,
78
79
80
81
82
83
84
85
                               SwizzlePattern pattern) {
  Map<Var, PrimExpr> vmap;
  Array<PrimExpr> input_size;
  for (size_t i = 0; i < forward_var.size(); i++) {
    vmap.Set(forward_var[i]->var, InputPlaceholder(i));
    CHECK(is_zero(forward_var[i]->dom->min));
    input_size.push_back(forward_var[i]->dom->extent);
  }
86
87
  forward_index =
      forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
88
89
90
91
92

  auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern);
  data_ = std::move(n);
}

93
94
SwizzledLayout::SwizzledLayout(Array<PrimExpr> input_size,
                               Array<PrimExpr> forward_index,
95
96
97
98
99
                               SwizzlePattern pattern) {
  auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern);
  data_ = std::move(n);
}

100
101
102
void SwizzledLayoutNode::VisitAttrs(tvm::AttrVisitor *v) {
  LayoutNode::VisitAttrs(v);
}
103

104
105
bool SwizzledLayoutNode::SEqualReduce(const SwizzledLayoutNode *other,
                                      SEqualReducer equal) const {
106
  return equal(this->InputShape(), other->InputShape()) &&
107
108
         equal(this->forward_index_, other->forward_index_) &&
         pattern_ == other->pattern_;
109
110
111
112
}

// Register SwizzledLayoutNode through TVM's node-type registration macro.
TVM_REGISTER_NODE_TYPE(SwizzledLayoutNode);

} // namespace tl
} // namespace tvm