/*!
 * \file legalize_safe_memory_access.cc
 * \brief legalize safe memory access
 */

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include "../op/builtin.h"
#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"

#include "loop_partition.h"
#include "loop_vectorize.h"

namespace tvm {
namespace tl {

using namespace tir;
using arith::IRMutatorWithAnalyzer;

// Helper class to find leaf For nodes in a given IR
class LeafForFinder : public StmtVisitor {
26
public:
27
28
  std::vector<For> leaf_for_nodes;

29
30
private:
  void VisitStmt_(const ForNode *op) final {
31
32
33
34
35
36
37
38
39
40
41
42
43
44
    has_child_for_ = false;
    bool parent_has_child_for = parent_has_child_for_;
    parent_has_child_for_ = false;

    StmtVisitor::VisitStmt(op->body);

    if (!has_child_for_) {
      leaf_for_nodes.push_back(GetRef<For>(op));
    }

    parent_has_child_for_ = parent_has_child_for;
    parent_has_child_for_ = true;
  }

45
private:
46
47
48
49
50
51
52
53
54
55
56
57
58
  bool has_child_for_ = false;
  bool parent_has_child_for_ = false;
};

// We will create a visitor to check BufferLoad and BufferStore nodes
// within this loop body. This visitor will:
// 1. Identify BufferLoad and BufferStore nodes.
// 2. Check if the buffer is in global scope.
// 3. For each index, compare against the buffer's shape.
//    If the index might exceed the shape (upper bound too large),
//    log a warning or handle accordingly.
struct GlobalMemChecker : public StmtExprVisitor {

59
  GlobalMemChecker(arith::Analyzer *analyzer) : analyzer_(analyzer) {}
60
  void VisitExpr_(const BufferLoadNode *op) final {
61
62
63
64
65
66
67
    // Check if the buffer is in global scope
    if (IsGlobalBuffer(op->buffer)) {
      CheckBufferIndices(op->buffer, op->indices, /*is_load=*/true);
    }
    StmtExprVisitor::VisitExpr_(op);
  }

68
  void VisitStmt_(const BufferStoreNode *op) final {
69
70
71
72
73
74
75
76
    // Check if the buffer is in global scope
    if (IsGlobalBuffer(op->buffer)) {
      CheckBufferIndices(op->buffer, op->indices, /*is_load=*/false);
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  // Helper function to determine if a buffer is global
77
78
79
80
81
82
  bool IsGlobalBuffer(const Buffer &buffer) {
    // The storage scope is often encoded in the buffer->data var name or
    // associated attributes. In typical TVM IR, global buffers have scope
    // "global". Here we assume a helper function GetPtrStorageScope is
    // available. If not, you might need to parse buffer->data->name_hint or
    // associated attributes.
83
84
85
86
87
    String scope = buffer.scope();
    return scope == "global";
  }

  // Check each index against the buffer shape dimensions
88
89
  void CheckBufferIndices(const Buffer &buffer, const Array<PrimExpr> &indices,
                          bool is_load) {
90
91
    // Ensure indices count matches buffer dimension
    if (indices.size() != buffer->shape.size()) {
92
93
94
      LOG(WARNING) << "Buffer access dimension mismatch: indices size ("
                   << indices.size() << ") vs. shape size ("
                   << buffer->shape.size() << ")";
95
96
97
98
99
100
101
      return;
    }

    for (size_t i = 0; i < indices.size(); i++) {
      PrimExpr index = indices[i];
      PrimExpr shape_dim = buffer->shape[i];

102
103
104
105
106
107
108
109
110
111
      bool has_variable = false;
      PostOrderVisit(index, [&](const ObjectRef &obj) {
        if (const VarNode *v = obj.as<VarNode>()) {
          has_variable = true;
        }
      });
      if (!has_variable) {
        continue;
      }

112
113
114
      // We want to check if index < shape_dim can be proven.
      // If analyzer->CanProve(index < shape_dim) returns false,
      // it means we cannot prove the access is within bounds.
115
      PrimExpr upper_bound_cond = index < shape_dim;
116
117
      if (!analyzer_->CanProve(upper_bound_cond,
                               arith::ProofStrength::kSymbolicBound)) {
118
119
120
121
        _conditions.push_back(upper_bound_cond);
      }
      // Check if index >= 0 can be proven.
      PrimExpr lower_bound_cond = index >= 0;
122
123
      if (!analyzer_->CanProve(lower_bound_cond,
                               arith::ProofStrength::kSymbolicBound)) {
124
        _conditions.push_back(lower_bound_cond);
125
126
127
128
129
130
      }
    }
  }

  Array<PrimExpr> GetConditions() { return _conditions; }

131
private:
132
  Array<PrimExpr> _conditions;
133
  arith::Analyzer *analyzer_;
134
135
136
};

class SafeMemorysRewriter : public StmtExprMutator {
137
  arith::Analyzer *analyzer_;
138

139
public:
140
141
142
  explicit SafeMemorysRewriter(Map<Buffer, PrimExpr> annotated_padding_map,
                               arith::Analyzer *analyzer)
      : annotated_padding_map_(annotated_padding_map), analyzer_(analyzer) {}
143

144
145
private:
  Stmt VisitStmt_(const BufferStoreNode *op) final {
146
147
    // Check if the buffer is in global scope
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
148

149
150
151
    GlobalMemChecker checker(analyzer_);
    checker(store);
    Array<PrimExpr> conditions = checker.GetConditions();
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166

    // Skip boundary check if the store value is an IfThenElse
    if (const IfThenElseNode *if_node = store->value.as<IfThenElseNode>()) {
      if (conditions.size() > 0) {
        LOG(WARNING)
            << "Skipping boundary check for store with IfThenElse value: "
            << store->value
            << "\nAs manual boundary check detected, potential out-of-bounds "
               "access may occur."
            << "\nAuto detect boundaries are " << conditions;
        return store;
      }
      return store;
    }

167
168
169
170
171
172
173
174
175
176
177
178
179
180
    if (conditions.size() == 0) {
      return store;
    }

    auto value = store->value;
    if (IsGlobalBuffer(store->buffer)) {
      Stmt store_with_conditions = store;
      for (auto cond : conditions) {
        store_with_conditions = IfThenElse(cond, store_with_conditions);
      }
      return store_with_conditions;
    } else if (isSharedBuffer(store->buffer)) {
      PrimExpr value = store->value;
      for (auto cond : conditions) {
181
182
        ICHECK(cond.dtype() == DataType::Bool(1))
            << "condition is not a boolean: " << cond;
183
        value = if_then_else(cond, value, GetPadding(store->buffer));
184
185
186
      }
      store.CopyOnWrite()->value = value;
      return store;
187
188
189
190
191
    } else if (IsLocalBuffer(store->buffer)) {
      PrimExpr value = store->value;
      for (auto cond : conditions) {
        ICHECK(cond.dtype() == DataType::Bool(1))
            << "condition is not a boolean: " << cond;
192
        value = if_then_else(cond, value, GetPadding(store->buffer));
193
194
195
196
197
198
      }
      store.CopyOnWrite()->value = value;
      return store;
    } else {
      LOG(FATAL) << "Check store buffer: " << store->buffer
                 << " is not a global or shared or local buffer";
199
200
201
202
203
204
205
    }

    return store;
  }

  // Handle Call Nodes
  // For example
206
207
208
  // T.call_extern("handle", "atomicAddx2", T.address_of(C),
  // T.address_of(C_shared))
  Stmt VisitStmt_(const EvaluateNode *op) final {
209
    auto evaluate = Downcast<Evaluate>(StmtExprMutator::VisitStmt_(op));
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
    if (const CallNode *call_op = op->value.as<CallNode>()) {
      auto call = Downcast<Call>(evaluate->value);
      if (call->op == builtin::call_extern()) {
        GlobalMemChecker checker(analyzer_);
        checker(call);
        Array<PrimExpr> conditions = checker.GetConditions();

        if (conditions.size() == 0) {
          return evaluate;
        }

        Stmt evaluate_with_conditions = evaluate;
        for (auto cond : conditions) {
          evaluate_with_conditions = IfThenElse(cond, evaluate_with_conditions);
        }
        return evaluate_with_conditions;
226
227
228
229
230
231
      }
    }

    return evaluate;
  }

232
233
234
235
236
  bool IsLocalBuffer(const Buffer &buffer) {
    String scope = buffer.scope();
    return scope == "local" || scope == "local.fragment";
  }

237
  bool isSharedBuffer(const Buffer &buffer) {
238
239
240
241
    String scope = buffer.scope();
    return scope == "shared" || scope == "shared.dyn";
  }

242
  bool IsGlobalBuffer(const Buffer &buffer) {
243
244
245
    String scope = buffer.scope();
    return scope == "global";
  }
246
247
248
249
250
251
252
253
254
  // Get the padding of the buffer
  PrimExpr GetPadding(const Buffer &buffer) {
    if (annotated_padding_map_.count(buffer)) {
      return annotated_padding_map_[buffer];
    }
    return make_zero(buffer->dtype);
  }

  Map<Buffer, PrimExpr> annotated_padding_map_;
255
256
257
258
};

// Class to legalize safe memory access by transforming them appropriately
class SafeMemoryLegalizer : IRMutatorWithAnalyzer {
259
public:
260
261
262
263
264
265
  // Static method to substitute and transform the given PrimFunc
  static PrimFunc Substitute(PrimFunc f) {
    arith::Analyzer analyzer;
    // Create an instance of the legalizer with the analyzer
    SafeMemoryLegalizer substituter(&analyzer);
    // Get a mutable copy of the function node
266
    PrimFuncNode *fptr = f.CopyOnWrite();
267
268
269
    for (const auto &[_, buffer] : f->buffer_map) {
      substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
270
271
272
273
274
    // Apply the legalizer to the function body
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

275
private:
276
  // Constructor initializing the base class with the analyzer
277
278
  SafeMemoryLegalizer(arith::Analyzer *analyzer)
      : arith::IRMutatorWithAnalyzer(analyzer) {}
279
280

  // Override the VisitStmt_ method to handle ForNode (loop statements)
281
  Stmt VisitStmt_(const ForNode *op) final {
282
283
284
285
    // Visit and potentially modify the loop node
    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    auto has_inner_loop = HasInnerLoop(for_node->body);
    if (!has_inner_loop) {
286
      SafeMemorysRewriter rewriter(annotated_padding_map_, analyzer_);
287
      for_node.CopyOnWrite()->body = rewriter(for_node->body);
288
289
      // // Detect Buffer Load Node in the loop body, collect the indices and
      // buffer size
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309

      // // Run the checker on the loop body
      // GlobalMemChecker checker(analyzer_);
      // checker(for_node->body);
      // Array<PrimExpr> conditions = checker.GetConditions();
      // auto body = for_node->body;
      // // Note that we might have duplicate conditions
      // // Which will be optimized by simplify pass
      // // Replace the loop body with the new body
      // for (auto cond : conditions) {
      //   body = IfThenElse(cond, body);
      // }
      // for_node.CopyOnWrite()->body = body;
      return std::move(for_node);
    }

    // Visit a For Node
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

310
311
312
313
314
315
316
317
318
319
  Stmt VisitStmt_(const BlockNode *op) final {
    for (auto buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    if (op->annotations.count(attr::kPaddingMap)) {
      auto map = op->annotations.Get(attr::kPaddingMap)
                     .as<Map<Var, PrimExpr>>()
                     .value();
      for (const auto &[var, padding] : map) {
        ICHECK(buffer_data_to_buffer_.count(var))
320
321
            << "buffer " << var << " is not found in the block "
            << buffer_data_to_buffer_;
322
323
324
325
326
327
328
        auto buffer = buffer_data_to_buffer_[var];
        annotated_padding_map_.Set(buffer, padding);
      }
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

329
  static bool HasInnerLoop(const Stmt &stmt) {
330
331
332
333
    LeafForFinder finder;
    finder(stmt);
    return finder.leaf_for_nodes.size() > 0;
  }
334
335
336

  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, PrimExpr> annotated_padding_map_;
337
338
339
340
341
342
343
};

// Create a pass that legalizes memory accesses in the IRModule by guarding
// every access that cannot be proven in-bounds.
tvm::transform::Pass LegalizeSafeMemoryAccess() {
  using namespace tir::transform;
  // Per-PrimFunc transformation applied by the pass.
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    // Honor the user-facing escape hatch: leave the function untouched
    // when the disable flag is set in the pass context.
    if (ctx->GetConfig<Bool>(kDisableSafeMemoryLegalize, Bool(false))
            .value()) {
      return f;
    }
    return SafeMemoryLegalizer::Substitute(std::move(f));
  };
  // Wrap the transformation into a PrimFunc pass at optimization level 0.
  return CreatePrimFuncPass(pass_func, 0, "tl.LegalizeSafeMemoryAccess", {});
}

// Register the pass globally so it can be used in the compilation pipeline
TVM_REGISTER_GLOBAL("tl.transform.LegalizeSafeMemoryAccess")
    .set_body_typed(LegalizeSafeMemoryAccess);

} // namespace tl
} // namespace tvm