"src/vscode:/vscode.git/clone" did not exist on "775dfe8a993be99f827fdb5b9d1ffa07bb11ec99"
config_index_bitwidth.cc 5.98 KB
Newer Older
1
#include "../op/builtin.h"
2
#include "arith/ir_mutator_with_analyzer.h"
3
4
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
5
6
7
8
9
10
11
12
13
#include <tvm/tir/builtin.h>
#include <tvm/tir/data_type_rewriter.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace tl {

using namespace tir;
14
using namespace arith;
15
16
17
18
19
20
class ConfigIndexBitwidthRewriter : public IndexDataTypeRewriter {
public:
  using Parent = IndexDataTypeRewriter;
  ConfigIndexBitwidthRewriter(int index_bitwidth)
      : _index_bitwidth_(index_bitwidth) {}

21
  Stmt operator()(const Stmt &s) { return VisitStmt(s); }
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72

protected:
  using Parent::VisitExpr_;
  using Parent::VisitStmt_;

  PrimExpr VisitExpr_(const VarNode *op) final {
    if (op->dtype.is_int() && op->dtype.bits() < 64) {
      DataType new_dtype = DataType::Int(64);
      if (!var_remap_.count(op)) {
        var_remap_[op] = Var(op->name_hint, new_dtype);
      }
    }
    return Parent::VisitExpr_(op);
  }

  PrimExpr VisitExpr_(const IntImmNode *op) final {
    if (is_enabled_ && op->dtype.is_int() && op->dtype.bits() < 64) {
      return IntImm(DataType::Int(_index_bitwidth_), op->value);
    }
    return GetRef<PrimExpr>(op);
  }

  PrimExpr VisitExpr_(const CastNode *op) final {
    if (is_enabled_ && op->dtype.is_int() && op->dtype.bits() < 64) {
      PrimExpr value = VisitExpr(op->value);
      return Cast(DataType::Int(_index_bitwidth_), value);
    }
    return Parent::VisitExpr_(op);
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    // Force indices to be int64
    bool is_enabled = is_enabled_;
    is_enabled_ = true;
    auto node = Downcast<BufferStore>(Parent::VisitStmt_(op));
    is_enabled_ = is_enabled;
    return std::move(node);
  }

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    // Force indices to be int64
    bool is_enabled = is_enabled_;
    is_enabled_ = true;
    auto node = Downcast<BufferLoad>(Parent::VisitExpr_(op));
    is_enabled_ = is_enabled;
    return std::move(node);
  }

  int _index_bitwidth_;
};

73
74
75
class IndexLegalizer : public IRMutatorWithAnalyzer {

public:
76
  static Stmt Rewrite(const Stmt &stmt) {
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
    Analyzer ana;
    auto pass = IndexLegalizer(&ana);
    return pass.VisitStmt(stmt);
  }

private:
  explicit IndexLegalizer(arith::Analyzer *ana) : IRMutatorWithAnalyzer(ana) {}

  class Int64Promoter : public IndexDataTypeRewriter {
  public:
    using Parent = IndexDataTypeRewriter;

    PrimExpr VisitExpr_(const VarNode *op) final {
      if (op->dtype.is_int() && op->dtype.bits() < 64) {
        return cast(DataType::Int(64), GetRef<Var>(op));
      }
      return GetRef<PrimExpr>(op);
    }

    PrimExpr VisitExpr_(const IntImmNode *op) final {
      if (op->dtype.is_int() && op->dtype.bits() < 64) {
        return IntImm(DataType::Int(64), op->value);
      }
      return GetRef<PrimExpr>(op);
    }

    PrimExpr VisitExpr_(const CastNode *op) final {
      if (op->dtype.is_int() && op->dtype.bits() < 64) {
        return cast(DataType::Int(64), op->value);
      }
      return GetRef<PrimExpr>(op);
    }

    Stmt VisitStmt_(const BufferStoreNode *op) final {
      // Force indices to be int64
      auto node = Downcast<BufferStore>(Parent::VisitStmt_(op));
      return std::move(node);
    }

    PrimExpr VisitExpr_(const BufferLoadNode *op) final {
      auto node = Downcast<BufferLoad>(Parent::VisitExpr_(op));
      return std::move(node);
    }
  };

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto buffer_store =
        Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));
    auto indices = buffer_store->indices;
126
    Array<PrimExpr> new_indices;
127
128
129
130
131
132
133
    for (auto index : indices) {
      if (index->dtype.is_int() && index->dtype.bits() < 64) {
        auto int_bound = analyzer_->const_int_bound(index);
        if (int_bound->max_value >= (1LL << (index->dtype.bits() - 1)) - 1 ||
            int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) {
          Int64Promoter promoter;
          index = promoter(index);
134
135
          new_indices.push_back(index);
          continue;
136
137
        }
      }
138
      new_indices.push_back(index);
139
    }
140
    buffer_store.CopyOnWrite()->indices = new_indices;
141
142
143
144
145
146
147
    return std::move(buffer_store);
  }

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto buffer_load =
        Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
    auto indices = buffer_load->indices;
148
    Array<PrimExpr> new_indices;
149
150
151
152
153
154
155
    for (auto index : indices) {
      if (index->dtype.is_int() && index->dtype.bits() < 64) {
        auto int_bound = analyzer_->const_int_bound(index);
        if (int_bound->max_value >= (1LL << (index->dtype.bits() - 1)) - 1 ||
            int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) {
          Int64Promoter promoter;
          index = promoter(index);
156
157
          new_indices.push_back(index);
          continue;
158
159
        }
      }
160
      new_indices.push_back(index);
161
    }
162
    buffer_load.CopyOnWrite()->indices = new_indices;
163
164
165
166
    return std::move(buffer_load);
  }
};

167
168
/*!
 * \brief Create the `tl.ConfigIndexBitwidth` PrimFunc pass.
 *
 * If the `kConfigIndexBitwidth` pass config is set, signed integer index
 * expressions narrower than that width are first widened by
 * ConfigIndexBitwidthRewriter. Regardless of the config, buffer indices whose
 * bounds may overflow their integer type are then promoted to int64 by
 * IndexLegalizer.
 */
tvm::transform::Pass ConfigIndexBitwidth() {
  using namespace tir::transform;
  auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) {
    auto *n = f.CopyOnWrite();
    // Read the pass config from the context this pass is being run in.
    Optional<Integer> opt_config_index_bitwidth =
        ctx->GetConfig(kConfigIndexBitwidth, Optional<Integer>());
    if (opt_config_index_bitwidth.defined()) {
      const auto config_index_bitwidth =
          opt_config_index_bitwidth.value()->value;
      // Reject widths that cannot describe a valid integer index dtype.
      ICHECK(config_index_bitwidth > 0 && config_index_bitwidth <= 64)
          << "Invalid value for " << kConfigIndexBitwidth << ": "
          << config_index_bitwidth << " (expected a value in [1, 64])";
      n->body = ConfigIndexBitwidthRewriter(
          static_cast<int>(config_index_bitwidth))(n->body);
    }
    // Legalize possibly out-of-range indices to int64.
    n->body = IndexLegalizer::Rewrite(n->body);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {});
}

186
187
188
189
190
TVM_FFI_STATIC_INIT_BLOCK({
  // Register the pass factory under the global name
  // `tl.transform.ConfigIndexBitwidth`.
  tvm::ffi::reflection::GlobalDef().def("tl.transform.ConfigIndexBitwidth",
                                        ConfigIndexBitwidth);
});
191
192
193

} // namespace tl
} // namespace tvm