/*!
 * \file legalize_negative_index.cc
 * \brief Legalize negative indices in buffer load/store expressions.
 *
 * Indices that can be proven negative are rewritten to
 * `buffer_extent + index`, so that the access uses a non-negative offset.
 */

#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"

namespace tvm {
namespace tl {

using namespace tir;
using arith::IRVisitorWithAnalyzer;

/*! \brief Result of trying to prove the sign of a single buffer index. */
enum class IndexSignState {
  /*! \brief The index was proven to be >= 0. */
  kNonNegative,
  /*! \brief The index was proven to be < 0. */
  kNegative,
  /*! \brief Neither sign could be proven. */
  kUnknown
};

27
28
29
30
31
using BufferAccessVariant =
    std::variant<const BufferLoadNode *, const BufferStoreNode *>;
using LoadStore2StateMap =
    std::unordered_map<BufferAccessVariant, std::vector<IndexSignState>>;

32
33
class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {
public:
34
  explicit NegativeIndexAnalyzer(LoadStore2StateMap *result)
35
36
      : result_(result) {}

37
38
39
private:
  std::vector<IndexSignState> ProcessIdx(const ffi::Array<PrimExpr> &indices,
                                         ffi::String buffer_name) {
40
    std::vector<IndexSignState> states;
41
    states.reserve(indices.size());
42

43
44
45
    for (size_t i = 0; i < indices.size(); ++i) {
      PrimExpr simplified = analyzer_.Simplify(indices[i]);
      IndexSignState state = IndexSignState::kUnknown;
46
47
48

      // Handle scalar indices with the standard analyzer
      if (simplified.dtype().lanes() == 1) {
49
50
51
52
53
54
55
56
57
        if (analyzer_.CanProve(simplified >= 0))
          state = IndexSignState::kNonNegative;
        else if (analyzer_.CanProve(simplified < 0))
          state = IndexSignState::kNegative;
        else
          DLOG(WARNING)
              << "LegalizeNegativeIndex: cannot prove non-negative index "
              << simplified << " for buffer " << buffer_name << " (axis " << i
              << ", index " + indices[i]->Script() + ").";
58
      }
59
60
61
      // Vector indices: try to reason about non-negativity/negativity
      // Common patterns are Ramp(base, stride, lanes) and Broadcast(value,
      // lanes).
62
      else if (const auto *ramp = simplified.as<RampNode>()) {
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
        // Compute a safe lower/upper bound for the vector lanes
        // lower_bound = base_min + min(0, stride_min) * (lanes - 1)
        // upper_bound = base_max + max(0, stride_max) * (lanes - 1)
        auto base_bound = analyzer_.const_int_bound(ramp->base);
        auto stride_bound = analyzer_.const_int_bound(ramp->stride);
        int lanes = *as_const_int(ramp->lanes);

        int64_t base_min = base_bound->min_value;
        int64_t base_max = base_bound->max_value;
        int64_t s_min = stride_bound->min_value;
        int64_t s_max = stride_bound->max_value;

        // Guard against overflow is not strictly necessary here because
        // bounds may be +/-inf represented by sentinel values.
        int64_t lower = base_min;
        if (s_min < 0)
          lower += s_min * (lanes - 1);
        int64_t upper = base_max;
        if (s_max > 0)
          upper += s_max * (lanes - 1);

84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
        if (lower >= 0)
          state = IndexSignState::kNonNegative;
        else if (upper < 0)
          state = IndexSignState::kNegative;
        else
          DLOG(WARNING)
              << "LegalizeNegativeIndex: cannot prove non-negative index "
              << simplified << " for buffer " << buffer_name << " (axis " << i
              << ", index " + indices[i]->Script() + ").";
      } else if (const auto *broadcast = simplified.as<BroadcastNode>()) {
        auto v = analyzer_.Simplify(broadcast->value);
        if (analyzer_.CanProve(v >= 0))
          state = IndexSignState::kNonNegative;
        else if (analyzer_.CanProve(v < 0))
          state = IndexSignState::kNegative;
        else {
100
101
          // Try const bound if proof unavailable
          auto vb = analyzer_.const_int_bound(v);
102
103
104
105
106
107
108
109
110
          if (vb->min_value >= 0)
            state = IndexSignState::kNonNegative;
          else if (vb->max_value < 0)
            state = IndexSignState::kNegative;
          else
            DLOG(WARNING)
                << "LegalizeNegativeIndex: cannot prove non-negative index "
                << simplified << " for buffer " << buffer_name << " (axis " << i
                << ", index " + indices[i]->Script() + ").";
111
112
        }
      }
113
114
      states.push_back(state);
    }
115

116
117
    return std::move(states);
  }
118

119
120
121
122
123
124
125
126
127
128
129
  bool NeedRecord(const std::vector<IndexSignState> &states) {
    return std::any_of(states.begin(), states.end(),
                       [](const IndexSignState &state) {
                         return state == IndexSignState::kUnknown ||
                                state == IndexSignState::kNegative;
                       });
  }

  void VisitExpr_(const BufferLoadNode *op) final {
    std::vector<IndexSignState> states =
        ProcessIdx(op->indices, op->buffer->name);
130

131
    if (NeedRecord(states))
132
133
134
135
136
      (*result_)[op] = std::move(states);

    IRVisitorWithAnalyzer::VisitExpr_(op);
  }

137
138
139
140
141
142
143
144
145
146
  void VisitStmt_(const BufferStoreNode *op) final {
    std::vector<IndexSignState> states =
        ProcessIdx(op->indices, op->buffer->name);

    if (NeedRecord(states))
      (*result_)[op] = std::move(states);

    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

147
private:
148
  LoadStore2StateMap *result_;
149
150
151
152
};

class NegativeIndexRewriter : public arith::IRMutatorWithAnalyzer {
public:
153
  static PrimFunc Apply(PrimFunc func, const LoadStore2StateMap &states) {
154
155
156
157
158
159
160
161
    arith::Analyzer analyzer;
    NegativeIndexRewriter rewriter(&analyzer, states);
    PrimFuncNode *func_node = func.CopyOnWrite();
    func_node->body = rewriter.VisitStmt(func_node->body);
    return func;
  }

private:
162
163
  NegativeIndexRewriter(arith::Analyzer *analyzer,
                        const LoadStore2StateMap &states)
164
165
      : arith::IRMutatorWithAnalyzer(analyzer), states_(states) {}

166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
  ffi::Array<PrimExpr> UpdateIdx(const ffi::Array<PrimExpr> &indices,
                                 const ffi::Array<PrimExpr> &buffer_shape,
                                 const std::vector<IndexSignState> &state_vec) {
    ICHECK_EQ(state_vec.size(), indices.size())
        << "State vector size mismatch for buffer load/store indices ("
        << indices << ")";
    ffi::Array<PrimExpr> new_indices = indices;
    for (size_t i = 0; i < indices.size(); ++i) {
      if (state_vec[i] != IndexSignState::kNegative)
        continue;
      new_indices.Set(i, analyzer_->Simplify(buffer_shape[i] + indices[i]));
    }
    return new_indices;
  }

181
182
183
184
185
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    BufferLoad load =
        Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op));

    auto it = states_.find(op);
186
    if (it == states_.end())
187
188
      return load;

189
190
191
    auto indices = UpdateIdx(load->indices, load->buffer->shape, it->second);
    return BufferLoad(load->buffer, indices, load->predicate);
  }
192

193
194
195
  Stmt VisitStmt_(const BufferStoreNode *op) final {
    BufferStore store =
        Downcast<BufferStore>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
196

197
198
199
    auto it = states_.find(op);
    if (it == states_.end())
      return store;
200

201
202
    auto indices = UpdateIdx(store->indices, store->buffer->shape, it->second);
    return BufferStore(store->buffer, store->value, indices, store->predicate);
203
204
  }

205
206
private:
  const LoadStore2StateMap &states_;
207
208
209
210
211
212
213
};

/*!
 * \brief Legalize negative buffer indices in a PrimFunc.
 *
 * Runs NegativeIndexAnalyzer over the body and, if it recorded any access,
 * hands the function to NegativeIndexRewriter. A function without a body, or
 * one for which nothing was recorded, is returned as is.
 *
 * \param func The function to process.
 * \return The (possibly rewritten) function.
 */
PrimFunc LegalizeNegativeIndex(PrimFunc func) {
  if (!func->body.defined())
    return func;

  LoadStore2StateMap access_states;
  NegativeIndexAnalyzer collector(&access_states);
  collector(func->body);

  if (access_states.empty())
    return func;
  return NegativeIndexRewriter::Apply(std::move(func), access_states);
}

/*!
 * \brief Create the PrimFunc pass named "tl.LegalizeNegativeIndex", which
 *        applies LegalizeNegativeIndex to each function it is given.
 */
tvm::transform::Pass LegalizeNegativeIndexPass() {
  using namespace tir::transform;
  auto legalize = [](PrimFunc f, const IRModule &, PassContext) {
    return LegalizeNegativeIndex(std::move(f));
  };
  return CreatePrimFuncPass(legalize, 0, "tl.LegalizeNegativeIndex", {});
}

// Registers the pass factory under the global name
// "tl.transform.LegalizeNegativeIndex".
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("tl.transform.LegalizeNegativeIndex",
                                        LegalizeNegativeIndexPass);
}

} // namespace tl
} // namespace tvm