/*!
 * \file legalize_negative_index.cc
 * \brief Legalize negative indices in buffer load/store expressions.
 */

#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"

namespace tvm {
namespace tl {

using namespace tir;
using arith::IRVisitorWithAnalyzer;

/*!
 * \brief Sign classification of a single buffer index (one axis of an access).
 */
enum class IndexSignState {
  /*! \brief The index was proven to be >= 0. */
  kNonNegative,
  /*! \brief The index was proven to be < 0. */
  kNegative,
  /*! \brief Neither sign could be proven. */
  kUnknown
};

/*! \brief Identity of a buffer access: pointer to its load or store node. */
using BufferAccessVariant =
    std::variant<const BufferLoadNode *, const BufferStoreNode *>;

/*!
 * \brief Maps a buffer access to the sign state of each of its indices
 *        (one entry per axis, in index order).
 */
using LoadStore2StateMap =
    std::unordered_map<BufferAccessVariant, std::vector<IndexSignState>>;

class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {
public:
34
  explicit NegativeIndexAnalyzer(LoadStore2StateMap *result)
35
36
      : result_(result) {}

37
38
39
private:
  std::vector<IndexSignState> ProcessIdx(const ffi::Array<PrimExpr> &indices,
                                         ffi::String buffer_name) {
40
    std::vector<IndexSignState> states;
41
    states.reserve(indices.size());
42

43
44
45
    for (size_t i = 0; i < indices.size(); ++i) {
      PrimExpr simplified = analyzer_.Simplify(indices[i]);
      IndexSignState state = IndexSignState::kUnknown;
46

47
48
49
50
51
52
53
54
55
56
57
58
59
      // Handle vector patterns first to avoid querying lanes() on
      // scalable vectors (which is not allowed at compile-time).
      if (const auto *ramp = simplified.as<RampNode>()) {
        // For scalable vectors, we cannot rely on a constant lane count.
        // Use sufficient (but not necessary) conditions:
        // - If base >= 0 and stride >= 0, all lanes are non-negative.
        // - If base < 0 and stride <= 0, all lanes are negative.
        bool base_nonneg = analyzer_.CanProve(ramp->base >= 0);
        bool base_neg = analyzer_.CanProve(ramp->base < 0);
        bool stride_nonneg = analyzer_.CanProve(ramp->stride >= 0);
        bool stride_nonpos = analyzer_.CanProve(ramp->stride <= 0);

        if (base_nonneg && stride_nonneg) {
60
          state = IndexSignState::kNonNegative;
61
        } else if (base_neg && stride_nonpos) {
62
          state = IndexSignState::kNegative;
63
        } else {
64
65
66
67
          DLOG(WARNING)
              << "LegalizeNegativeIndex: cannot prove non-negative index "
              << simplified << " for buffer " << buffer_name << " (axis " << i
              << ", index " + indices[i]->Script() + ").";
68
        }
69
70
71
72
73
74
75
      } else if (const auto *broadcast = simplified.as<BroadcastNode>()) {
        auto v = analyzer_.Simplify(broadcast->value);
        if (analyzer_.CanProve(v >= 0))
          state = IndexSignState::kNonNegative;
        else if (analyzer_.CanProve(v < 0))
          state = IndexSignState::kNegative;
        else {
76
77
          // Try const bound if proof unavailable
          auto vb = analyzer_.const_int_bound(v);
78
79
80
81
82
83
84
85
86
          if (vb->min_value >= 0)
            state = IndexSignState::kNonNegative;
          else if (vb->max_value < 0)
            state = IndexSignState::kNegative;
          else
            DLOG(WARNING)
                << "LegalizeNegativeIndex: cannot prove non-negative index "
                << simplified << " for buffer " << buffer_name << " (axis " << i
                << ", index " + indices[i]->Script() + ").";
87
        }
88
89
90
91
92
93
94
95
96
97
98
99
100
101
      } else {
        // Assume scalar (or non-Ramp/Broadcast) index; avoid querying lanes().
        // Fall back to scalar reasoning. If this expression is actually a
        // vector-but-not-Ramp/Broadcast, treat as unknown to be safe.
        // Try to prove scalar first; if proof fails, leave as unknown.
        if (analyzer_.CanProve(simplified >= 0))
          state = IndexSignState::kNonNegative;
        else if (analyzer_.CanProve(simplified < 0))
          state = IndexSignState::kNegative;
        else
          DLOG(WARNING)
              << "LegalizeNegativeIndex: cannot prove non-negative index "
              << simplified << " for buffer " << buffer_name << " (axis " << i
              << ", index " + indices[i]->Script() + ").";
102
      }
103
104
      states.push_back(state);
    }
105

106
107
    return std::move(states);
  }
108

109
110
111
112
113
114
115
116
117
118
119
  bool NeedRecord(const std::vector<IndexSignState> &states) {
    return std::any_of(states.begin(), states.end(),
                       [](const IndexSignState &state) {
                         return state == IndexSignState::kUnknown ||
                                state == IndexSignState::kNegative;
                       });
  }

  void VisitExpr_(const BufferLoadNode *op) final {
    std::vector<IndexSignState> states =
        ProcessIdx(op->indices, op->buffer->name);
120

121
    if (NeedRecord(states))
122
123
124
125
126
      (*result_)[op] = std::move(states);

    IRVisitorWithAnalyzer::VisitExpr_(op);
  }

127
128
129
130
131
132
133
134
135
136
  void VisitStmt_(const BufferStoreNode *op) final {
    std::vector<IndexSignState> states =
        ProcessIdx(op->indices, op->buffer->name);

    if (NeedRecord(states))
      (*result_)[op] = std::move(states);

    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

137
private:
138
  LoadStore2StateMap *result_;
139
140
141
142
};

class NegativeIndexRewriter : public arith::IRMutatorWithAnalyzer {
public:
143
  static PrimFunc Apply(PrimFunc func, const LoadStore2StateMap &states) {
144
145
146
147
148
149
150
151
    arith::Analyzer analyzer;
    NegativeIndexRewriter rewriter(&analyzer, states);
    PrimFuncNode *func_node = func.CopyOnWrite();
    func_node->body = rewriter.VisitStmt(func_node->body);
    return func;
  }

private:
152
153
  NegativeIndexRewriter(arith::Analyzer *analyzer,
                        const LoadStore2StateMap &states)
154
155
      : arith::IRMutatorWithAnalyzer(analyzer), states_(states) {}

156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
  ffi::Array<PrimExpr> UpdateIdx(const ffi::Array<PrimExpr> &indices,
                                 const ffi::Array<PrimExpr> &buffer_shape,
                                 const std::vector<IndexSignState> &state_vec) {
    ICHECK_EQ(state_vec.size(), indices.size())
        << "State vector size mismatch for buffer load/store indices ("
        << indices << ")";
    ffi::Array<PrimExpr> new_indices = indices;
    for (size_t i = 0; i < indices.size(); ++i) {
      if (state_vec[i] != IndexSignState::kNegative)
        continue;
      new_indices.Set(i, analyzer_->Simplify(buffer_shape[i] + indices[i]));
    }
    return new_indices;
  }

171
172
173
174
175
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    BufferLoad load =
        Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op));

    auto it = states_.find(op);
176
    if (it == states_.end())
177
178
      return load;

179
180
181
    auto indices = UpdateIdx(load->indices, load->buffer->shape, it->second);
    return BufferLoad(load->buffer, indices, load->predicate);
  }
182

183
184
185
  Stmt VisitStmt_(const BufferStoreNode *op) final {
    BufferStore store =
        Downcast<BufferStore>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
186

187
188
189
    auto it = states_.find(op);
    if (it == states_.end())
      return store;
190

191
192
    auto indices = UpdateIdx(store->indices, store->buffer->shape, it->second);
    return BufferStore(store->buffer, store->value, indices, store->predicate);
193
194
  }

195
196
private:
  const LoadStore2StateMap &states_;
197
198
199
200
201
202
203
};

/*!
 * \brief Rewrite provably negative buffer indices of a PrimFunc.
 *
 * Runs NegativeIndexAnalyzer over the body, then NegativeIndexRewriter when
 * at least one access was recorded.
 *
 * \param func The function to legalize.
 * \return \p func unchanged when it has no body or nothing was recorded;
 *         otherwise the rewritten function.
 */
PrimFunc LegalizeNegativeIndex(PrimFunc func) {
  if (!func->body.defined()) {
    return func;
  }

  LoadStore2StateMap states;
  NegativeIndexAnalyzer analyzer(&states);
  analyzer(func->body);
  if (states.empty()) {
    return func;
  }

  return NegativeIndexRewriter::Apply(std::move(func), states);
}

/*!
 * \brief Create the PrimFunc-level pass that applies LegalizeNegativeIndex.
 * \return A pass named "tl.LegalizeNegativeIndex" created with opt level 0
 *         and no required passes.
 */
tvm::transform::Pass LegalizeNegativeIndexPass() {
  using namespace tir::transform;
  return CreatePrimFuncPass(
      [](PrimFunc f, const IRModule &, PassContext) {
        return LegalizeNegativeIndex(std::move(f));
      },
      0, "tl.LegalizeNegativeIndex", {});
}

// Register the pass constructor under the name
// "tl.transform.LegalizeNegativeIndex" via the FFI reflection GlobalDef.
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("tl.transform.LegalizeNegativeIndex",
                                        LegalizeNegativeIndexPass);
}

} // namespace tl
} // namespace tvm