/*!
 * \file legalize_safe_memory_access.cc
 * \brief Legalize safe memory access: guard global-buffer loads and stores
 *        whose indices cannot be proven to be within the buffer shape.
 */

6
#include <tvm/ffi/reflection/registry.h>
7
8
9
10
11
12
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

13
14
#include <utility>

15
#include "../op/builtin.h"
16
#include "../op/parallel.h"
17
#include "arith/ir_mutator_with_analyzer.h"
18
19
20
21
22
23
24
25
26
#include "loop_partition.h"
#include "loop_vectorize.h"

namespace tvm {
namespace tl {

using namespace tir;
using arith::IRMutatorWithAnalyzer;

27
// GlobalMemChecker for a BufferLoad/BufferStore node:
28
29
30
31
32
33
34
// 1. Identify BufferLoad and BufferStore nodes.
// 2. Check if the buffer is in global scope.
// 3. For each index, compare against the buffer's shape.
//    If the index might exceed the shape (upper bound too large),
//    log a warning or handle accordingly.
struct GlobalMemChecker : public StmtExprVisitor {

35
36
37
  GlobalMemChecker(arith::Analyzer *analyzer, bool recursively_collect_conds)
      : analyzer_(analyzer),
        recursively_collect_conds_(recursively_collect_conds) {}
38
  void VisitExpr_(const BufferLoadNode *op) final {
39
    // Check if the buffer is in global scope
40
41
    // This is because we are writing TilePrograms, where out of bounds
    // accesses only happen in the global buffer.
42
43
44
    if (IsGlobalBuffer(op->buffer)) {
      CheckBufferIndices(op->buffer, op->indices, /*is_load=*/true);
    }
45
46
47
    if (recursively_collect_conds_) {
      StmtExprVisitor::VisitExpr_(op);
    }
48
49
  }

50
  void VisitStmt_(const BufferStoreNode *op) final {
51
52
53
54
    // Check if the buffer is in global scope
    if (IsGlobalBuffer(op->buffer)) {
      CheckBufferIndices(op->buffer, op->indices, /*is_load=*/false);
    }
55
56
57
    if (recursively_collect_conds_) {
      StmtExprVisitor::VisitStmt_(op);
    }
58
59
60
  }

  // Helper function to determine if a buffer is global
61
62
63
64
65
66
  bool IsGlobalBuffer(const Buffer &buffer) {
    // The storage scope is often encoded in the buffer->data var name or
    // associated attributes. In typical TVM IR, global buffers have scope
    // "global". Here we assume a helper function GetPtrStorageScope is
    // available. If not, you might need to parse buffer->data->name_hint or
    // associated attributes.
67
68
69
70
71
    String scope = buffer.scope();
    return scope == "global";
  }

  // Check each index against the buffer shape dimensions
72
73
  void CheckBufferIndices(const Buffer &buffer, const Array<PrimExpr> &indices,
                          bool is_load) {
74
75
    // Ensure indices count matches buffer dimension
    if (indices.size() != buffer->shape.size()) {
76
77
78
      LOG(WARNING) << "Buffer access dimension mismatch: indices size ("
                   << indices.size() << ") vs. shape size ("
                   << buffer->shape.size() << ")";
79
80
81
82
83
84
85
      return;
    }

    for (size_t i = 0; i < indices.size(); i++) {
      PrimExpr index = indices[i];
      PrimExpr shape_dim = buffer->shape[i];

86
      bool is_index_constant = true;
87
88
      PostOrderVisit(index, [&](const ObjectRef &obj) {
        if (const VarNode *v = obj.as<VarNode>()) {
89
90
91
92
          is_index_constant = false;
        }
        if (const BufferLoadNode *v = obj.as<BufferLoadNode>()) {
          is_index_constant = false;
93
94
        }
      });
95
      if (is_index_constant) {
96
        // If index is a constant, we can skip the check
97
98
99
        continue;
      }

100
101
102
      // We want to check if index < shape_dim can be proven.
      // If analyzer->CanProve(index < shape_dim) returns false,
      // it means we cannot prove the access is within bounds.
103
      PrimExpr upper_bound_cond = index < shape_dim;
104
105
      if (!analyzer_->CanProve(upper_bound_cond,
                               arith::ProofStrength::kSymbolicBound)) {
106
107
108
109
        _conditions.push_back(upper_bound_cond);
      }
      // Check if index >= 0 can be proven.
      PrimExpr lower_bound_cond = index >= 0;
110
111
      if (!analyzer_->CanProve(lower_bound_cond,
                               arith::ProofStrength::kSymbolicBound)) {
112
        _conditions.push_back(lower_bound_cond);
113
114
115
116
117
118
      }
    }
  }

  Array<PrimExpr> GetConditions() { return _conditions; }

119
private:
120
  Array<PrimExpr> _conditions;
121
  arith::Analyzer *analyzer_;
122
  bool recursively_collect_conds_;
123
124
};

125
class SafeMemorysRewriter : public IRMutatorWithAnalyzer {
126
public:
127
128
129
130
131
132
133
134
135
136
137
138
139
140
  // Static method to substitute and transform the given PrimFunc
  static PrimFunc Substitute(PrimFunc f) {
    arith::Analyzer analyzer;
    // Create an instance of the legalizer with the analyzer
    SafeMemorysRewriter substituter(&analyzer);
    // Get a mutable copy of the function node
    PrimFuncNode *fptr = f.CopyOnWrite();
    for (const auto &[_, buffer] : f->buffer_map) {
      substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    // Apply the legalizer to the function body
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }
141

142
private:
143
144
145
146
147
  // Constructor initializing the base class with the analyzer
  SafeMemorysRewriter(arith::Analyzer *analyzer)
      : arith::IRMutatorWithAnalyzer(analyzer) {}
  // Constructor initializing the base class with the analyzer

148
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
149
    auto load = Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171

    // For Load/Store, we only check the current node, not its children.
    // Since rewriter will recursively visit children.
    GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/false);
    checker(load);
    Array<PrimExpr> conditions = checker.GetConditions();

    if (conditions.empty()) {
      return load;
    }

    // For loading, we can always use safe value if the access is out of
    // bounds
    PrimExpr value = load;
    for (auto cond : conditions) {
      ICHECK(cond.dtype() == DataType::Bool(1))
          << "condition is not a boolean: " << cond;
      value = if_then_else(cond, value, GetSafeValue(load->buffer));
    }
    return value;
  }

172
  Stmt VisitStmt_(const BufferStoreNode *op) final {
173
    // Check if the buffer is in global scope
174
    auto store = Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));
175

176
    GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/false);
177
178
    checker(store);
    Array<PrimExpr> conditions = checker.GetConditions();
179
180
181

    // Skip boundary check if the store value is an IfThenElse
    if (const IfThenElseNode *if_node = store->value.as<IfThenElseNode>()) {
182
      if (!conditions.empty()) {
183
184
185
186
187
188
189
190
191
192
193
        LOG(WARNING)
            << "Skipping boundary check for store with IfThenElse value: "
            << store->value
            << "\nAs manual boundary check detected, potential out-of-bounds "
               "access may occur."
            << "\nAuto detect boundaries are " << conditions;
        return store;
      }
      return store;
    }

194
    if (conditions.empty()) {
195
196
197
      return store;
    }

198
199
200
201
    // If a store is out of bounds, we skip the corresponding stmt directly.
    Stmt store_with_conditions = store;
    for (auto cond : conditions) {
      store_with_conditions = IfThenElse(cond, store_with_conditions);
202
    }
203
    return store_with_conditions;
204
205
  }

206
  // Recursively check Load/Store in the call arguments.
207
  // For example
208
209
  // T.call_extern("handle", "atomicAddx2", T.address_of(C),
  // T.address_of(C_shared))
210
211
212
213
214
215
216
217

  // NOTE(chaofan): This is currently not the most rigorous solution.
  // The check here is primarily intended to handle extern functions like
  // atomicAdd, which may involve memory access. Due to their special nature,
  // the BufferLoad in their parameters might be used for boundary checks of the
  // current statement. The current solution adopts a simplified approach:
  // directly applying the boundary constraints of all parameters to the
  // statement. While not entirely precise, it addresses most common scenarios.
218
  Stmt VisitStmt_(const EvaluateNode *op) final {
219
220
    auto evaluate = Downcast<Evaluate>(op);

221
    if (const CallNode *call_op = op->value.as<CallNode>()) {
222
      auto call = Downcast<Call>(op->value);
223
      if (call->op == builtin::call_extern()) {
224
225
226
227
        // For CallExtern, we recursively collect conditions from all children.
        // Since we cannot rewrite any BufferLoad in its children (Rewrite will
        // cause potential Nullptr exception).
        GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/true);
228
229
230
        checker(call);
        Array<PrimExpr> conditions = checker.GetConditions();

231
        if (conditions.empty()) {
232
233
234
235
236
237
238
239
          return evaluate;
        }

        Stmt evaluate_with_conditions = evaluate;
        for (auto cond : conditions) {
          evaluate_with_conditions = IfThenElse(cond, evaluate_with_conditions);
        }
        return evaluate_with_conditions;
240
241
242
243
244
245
      }
    }

    return evaluate;
  }

246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
  Stmt VisitStmt_(const BlockNode *op) final {
    for (auto buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    if (op->annotations.count(attr::kSafeValueMap)) {
      auto map = op->annotations.Get(attr::kSafeValueMap)
                     ->as<Map<Var, PrimExpr>>()
                     .value();
      for (const auto &[var, safe_value] : map) {
        ICHECK(buffer_data_to_buffer_.count(var))
            << "buffer " << var << " is not found in the block "
            << buffer_data_to_buffer_;
        auto buffer = buffer_data_to_buffer_[var];
        annotated_safe_value_map_.Set(buffer, safe_value);
      }
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

265
266
  bool IsLocalBuffer(const Buffer &buffer) {
    String scope = buffer.scope();
267
268
    return scope == "local" || scope == "local.fragment" ||
           scope == "local.var";
269
270
  }

271
  bool isSharedBuffer(const Buffer &buffer) {
272
273
274
275
    String scope = buffer.scope();
    return scope == "shared" || scope == "shared.dyn";
  }

276
  bool IsGlobalBuffer(const Buffer &buffer) {
277
278
279
    String scope = buffer.scope();
    return scope == "global";
  }
280
281
282
283
  // Get the safe value of the buffer
  PrimExpr GetSafeValue(const Buffer &buffer) {
    if (annotated_safe_value_map_.count(buffer)) {
      return annotated_safe_value_map_[buffer];
284
285
286
287
288
    }
    return make_zero(buffer->dtype);
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
289
  Map<Buffer, PrimExpr> annotated_safe_value_map_;
290
291
292
293
294
295
};

// Create a pass that guards global memory accesses which cannot be proven to
// be in bounds, by running SafeMemorysRewriter on each PrimFunc. The function
// is returned unchanged when the kDisableSafeMemoryLegalize pass-config
// option is set to true (it defaults to false).
tvm::transform::Pass LegalizeSafeMemoryAccess() {
  using namespace tir::transform;
  // Define the transformation function to be applied
  auto pass_func = [](PrimFunc f, const IRModule &m, PassContext ctx) {
    bool disable_safe_memory_legalize =
        ctx->GetConfig<Bool>(kDisableSafeMemoryLegalize, Bool(false)).value();
    if (disable_safe_memory_legalize) {
      return f;
    }
    return SafeMemorysRewriter::Substitute(std::move(f));
  };
  // Create and return a PrimFunc pass with the transformation function
  return CreatePrimFuncPass(pass_func, 0, "tl.LegalizeSafeMemoryAccess", {});
}

// Register the pass globally so it can be used in the compilation pipeline
309
TVM_FFI_STATIC_INIT_BLOCK() {
310
311
312
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.LegalizeSafeMemoryAccess",
                        LegalizeSafeMemoryAccess);
313
}
314

315
316
} // namespace tl
} // namespace tvm