/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file storage_access.cc
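 * \brief Visitor that records buffer reads, writes, and synchronization
 *        events per statement scope for storage access analysis.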
 */
#include "storage_access.h"

#include <tvm/arith/analyzer.h>
#include <tvm/target/target_info.h>
#include <tvm/tir/op.h>

#include <string>
#include <utility>

#include "../op/builtin.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

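/*!
 * \brief Record a read access entry for a buffer load, then visit children.
 */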
void TileLangStorageAccessVisitor::VisitExpr_(const BufferLoadNode *op) {
  Var buf = op->buffer->data;
  buffer_data_to_buffer_.Set(buf, op->buffer);
  StorageScope scope = GetScope(buf);
  if (Enabled(buf.get(), scope)) {
    ICHECK(allow_append_) << GetRef<BufferLoad>(op) << " " << scope.to_string();
    AccessEntry e;
    e.threads = env_threads();
    e.thread_range = this->ComputeThreadRange(e.threads);
    e.buffer = buf;
    e.buffer_indices = op->indices;
    e.dtype = op->dtype.element_of();
    for (const auto &index : op->indices) {
      e.touched.push_back(arith::IntSet::Vector(index));
    }
    e.type = kRead;
    e.scope = scope;
    curr_stmt_.access.emplace_back(std::move(e));
  }
  // traverse child
  IRVisitorWithAnalyzer::VisitExpr_(op);
}

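/*!
 * \brief Record a write access entry for a buffer store, visit children, and
 * push the collected statement entry into the current scope.
 */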
void TileLangStorageAccessVisitor::VisitStmt_(const BufferStoreNode *op) {
  allow_append_ = true;
  ICHECK_EQ(curr_stmt_.access.size(), 0U);
  curr_stmt_.stmt = op;

  Var buf = op->buffer->data;
  buffer_data_to_buffer_.Set(buf, op->buffer);
  StorageScope scope = GetScope(buf);
  if (Enabled(buf.get(), scope)) {
    AccessEntry e;
    e.threads = env_threads();
    e.thread_range = this->ComputeThreadRange(e.threads);
    e.buffer = buf;
    e.buffer_indices = op->indices;
    e.dtype = op->value.dtype().element_of();
    for (const auto &index : op->indices) {
      e.touched.push_back(arith::IntSet::Vector(index));
    }
    e.type = kWrite;
    e.scope = scope;
    curr_stmt_.access.emplace_back(std::move(e));
  }
  // traverse child
  IRVisitorWithAnalyzer::VisitStmt_(op);
  // push to the scope
  scope_.back().push_back(curr_stmt_);
  // clear access entry.
  curr_stmt_.access.clear();
  allow_append_ = false;
}

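/*!
 * \brief Collect accesses performed inside an Evaluate node (e.g. intrinsic
 * calls) and push them into the current scope if any were recorded.
 */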
void TileLangStorageAccessVisitor::VisitStmt_(const EvaluateNode *op) {
  allow_append_ = true;
  ICHECK_EQ(curr_stmt_.access.size(), 0U);
  curr_stmt_.stmt = op;
  IRVisitorWithAnalyzer::VisitStmt_(op);
  // push to the scope
  if (!curr_stmt_.access.empty()) {
    scope_.back().push_back(curr_stmt_);
    curr_stmt_.access.clear();
  }
  allow_append_ = false;
}

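/*!
 * \brief Collect accesses made by the bound value of a let-binding, then
 * traverse the body as a separate statement.
 */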
void TileLangStorageAccessVisitor::VisitStmt_(const LetStmtNode *op) {
  allow_append_ = true;
  ICHECK_EQ(curr_stmt_.access.size(), 0U);
  curr_stmt_.stmt = op;
  this->VisitExpr(op->value);
  // push to the scope
  scope_.back().push_back(curr_stmt_);
  // clear access entry.
  curr_stmt_.access.clear();
  allow_append_ = false;
  // traverse body block
  this->VisitStmt(op->body);
}

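/*!
 * \brief Register block-allocated buffers so that later accesses through
 * their data vars can be mapped back to Buffer objects.
 */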
void TileLangStorageAccessVisitor::VisitStmt_(const BlockNode *op) {
  auto block = Downcast<Block>(op);
  for (const auto &buffer : block->alloc_buffers) {
    ICHECK(buffer->IsInstance<BufferNode>());
    buffer_data_to_buffer_.Set(buffer->data, buffer);
  }
  IRVisitorWithAnalyzer::VisitStmt_(op);
}

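/*!
 * \brief Handle attribute statements: track double-buffer writes, enter
 * coprocessor and thread-extent environments, and skip hand-threaded regions.
 */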
void TileLangStorageAccessVisitor::VisitStmt_(const AttrStmtNode *op) {
  if (op->attr_key == tvm::tir::attr::double_buffer_write) {
    ICHECK(double_buffer_write_ == nullptr);
    double_buffer_write_ = op->node.as<VarNode>();
    scope_.push_back(std::vector<StmtEntry>());
    IRVisitorWithAnalyzer::VisitStmt_(op);
    StmtEntry s;
    s.stmt = op;
    s.access = Summarize(std::move(scope_.back()), nullptr);
    scope_.pop_back();
    if (!s.access.empty()) {
      for (AccessEntry &e : s.access) {
        if (e.type == kWrite && e.buffer.get() == double_buffer_write_) {
          e.double_buffer_write = true;
        }
      }
      scope_.back().emplace_back(std::move(s));
    }
    double_buffer_write_ = nullptr;
  } else if (op->attr_key == tvm::tir::attr::coproc_scope) {
    IterVar iv = Downcast<IterVar>(op->node);
    env_threads_.push_back(iv);
    IRVisitorWithAnalyzer::VisitStmt_(op);
    env_threads_.pop_back();
  } else if (op->attr_key == tvm::tir::attr::thread_extent) {
    IterVar iv = Downcast<IterVar>(op->node);
    env_threads_.push_back(iv);
    ICHECK_NE(iv->thread_tag.length(), 0U);
    analyzer_.Bind(
        iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0), op->value));

    if (!in_device_env_) {
      in_device_env_ = true;
      scope_.push_back(std::vector<StmtEntry>());
      IRVisitorWithAnalyzer::VisitStmt_(op);
      // no need to take the result as the thread barrier automatically syncs.
      Summarize(std::move(scope_.back()), nullptr);
      in_device_env_ = false;
      scope_.pop_back();
    } else {
      IRVisitorWithAnalyzer::VisitStmt_(op);
    }
    env_threads_.pop_back();
  } else if (op->attr_key == tvm::tir::attr::hand_threaded) {
    // skip this pass on blocks that were hand_threaded
    // this avoids control flow and read/write conflicts
    // between hand-threaded kernels and automatic threading
  } else {
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }
}

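/*!
 * \brief Summarize accesses made inside a loop and relax each touched set
 * over the loop variable's full range [min, min + extent).
 */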
void TileLangStorageAccessVisitor::VisitStmt_(const ForNode *op) {
  scope_.push_back(std::vector<StmtEntry>());
  IRVisitorWithAnalyzer::VisitStmt_(op);
  StmtEntry s;
  s.stmt = op;
  s.access = Summarize(std::move(scope_.back()), op);
  scope_.pop_back();
  if (!s.access.empty()) {
    // relax the touched set to contain all ranges in the loop.
    std::unordered_map<const VarNode *, arith::IntSet> relax_map;
    relax_map[op->loop_var.get()] =
        arith::IntSet::FromRange(Range::FromMinExtent(op->min, op->extent));
    for (AccessEntry &e : s.access) {
      if (e.buffer.defined()) {
        ICHECK(!e.touched.empty());
        Array<arith::IntSet> new_touched;
        for (const auto &touched : e.touched) {
          new_touched.push_back(arith::EvalSet(touched, relax_map));
        }
        e.touched = std::move(new_touched);
      }
    }
    scope_.back().emplace_back(std::move(s));
  }
}

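/*!
 * \brief Check whether a condition is wrapped in tvm_thread_invariant(),
 * i.e. guaranteed to evaluate identically across all threads.
 */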
bool IsThreadInvariant(const PrimExpr &cond) {
  if (auto call = cond.as<CallNode>()) {
    if (auto opt_call_op = call->op.as<Op>()) {
      const auto &call_op = opt_call_op.value();
      if (call_op.same_as(builtin::tvm_thread_invariant())) {
        return true;
      }
    }
  }
  return false;
}

/*!
 * \brief Visit an IfThenElse statement and collect storage access summaries
 *        for its branches.
 *
 * Visits the condition and both branches, summarizing buffer reads, writes,
 * and synchronization events under the condition's constraints. If the
 * condition is not thread-invariant, condition_counter_ is incremented for
 * the duration of processing.
 *
 * Behavior and side effects:
 * - Evaluates the condition (via ExtractRealCondition) and applies it as an
 *   analyzer constraint while summarizing the then-branch.
 * - For the else-branch (when present), applies the analyzer-simplified
 *   negation of the condition as the constraint.
 * - Appends a combined StmtEntry for the IfThenElseNode to the current scope.
 * - Temporarily toggles allow_append_ and clears curr_stmt_.access around
 *   condition evaluation and branch summarization.
 */
void TileLangStorageAccessVisitor::VisitStmt_(const IfThenElseNode *op) {
  bool is_thread_invariant = IsThreadInvariant(op->condition);
  if (!is_thread_invariant) {
    ++condition_counter_;
  }

  allow_append_ = true;
  this->VisitExpr(op->condition);
  PrimExpr real_condition = ExtractRealCondition(op->condition);

  curr_stmt_.access.clear();
  allow_append_ = false;

  scope_.push_back(std::vector<StmtEntry>());
  {
    With<arith::ConstraintContext> constraint(&analyzer_, real_condition);
    this->VisitStmt(op->then_case);
  }

  StmtEntry s;
  s.stmt = op;
  s.access = Summarize(std::move(scope_.back()), nullptr);
  scope_.pop_back();
  if (op->else_case) {
    scope_.push_back(std::vector<StmtEntry>());
    {
      With<arith::ConstraintContext> constraint(
          &analyzer_, analyzer_.rewrite_simplify(Not(real_condition)));
      this->VisitStmt(op->else_case.value());
    }
    auto v = Summarize(std::move(scope_.back()), nullptr);
    scope_.pop_back();
    s.access.insert(s.access.end(), v.begin(), v.end());
  }
  scope_.back().emplace_back(std::move(s));
  if (!is_thread_invariant) {
    --condition_counter_;
  }
}

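/*!
 * \brief Summarize accesses made inside a while-loop body, treating the
 * condition as thread-dependent unless marked thread-invariant.
 */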
void TileLangStorageAccessVisitor::VisitStmt_(const WhileNode *op) {
  bool is_thread_invariant = IsThreadInvariant(op->condition);
  if (!is_thread_invariant) {
    ++condition_counter_;
  }
  this->VisitExpr(op->condition);
  scope_.push_back(std::vector<StmtEntry>());
  this->VisitStmt(op->body);
  StmtEntry s;
  s.stmt = op;
  s.access = Summarize(std::move(scope_.back()), nullptr);
  scope_.pop_back();
  scope_.back().emplace_back(std::move(s));
  if (!is_thread_invariant) {
    --condition_counter_;
  }
}

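/*!
 * \brief Record accesses made through intrinsic calls: address_of,
 * tvm_access_ptr, and tvm_storage_sync. Accesses inside tma_load /
 * tma_load_im2col calls are marked as async copies.
 */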
void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
  // Mark async TMA load context so that tvm_access_ptr within the call
  // can be tagged accordingly.
  auto is_tma_load = [&]() {
    if (auto opt = op->op.as<Op>()) {
      const Op &call_op = opt.value();
      return call_op.same_as(tl::tma_load()) ||
             call_op.same_as(tl::tma_load_im2col());
    }
    return false;
  }();
  if (is_tma_load) {
    tma_depth_++;
    for (const auto &a : op->args) {
      this->VisitExpr(a);
    }
    tma_depth_--;
    return;
  }
  if (op->op.same_as(builtin::address_of())) {
    ICHECK_EQ(op->args.size(), 1U);
    if (auto load = op->args[0].as<BufferLoadNode>()) {
      Buffer buffer = load->buffer;
      DataType dtype = buffer->dtype;
      const VarNode *buffer_var = buffer->data.as<VarNode>();
      buffer_data_to_buffer_.Set(GetRef<Var>(buffer_var), buffer);
      StorageScope scope = GetScope(GetRef<Var>(buffer_var));
      Array<Range> buffer_ranges;
      // The load must be fully indexed: one index per buffer dimension.
      ICHECK_EQ(buffer->shape.size(), load->indices.size());
      // Use buffer shape and indices to compute the buffer_ranges for each
      // dimension.
      for (size_t i = 0; i < buffer->shape.size(); ++i) {
        PrimExpr min = load->indices[i];
        PrimExpr extent = make_const(buffer->shape[i].dtype(), 1);
        buffer_ranges.push_back(Range::FromMinExtent(min, extent));
      }
      if (Enabled(buffer_var, scope)) {
        ICHECK(allow_append_);
        AccessEntry e;
        e.threads = env_threads();
        e.thread_range = this->ComputeThreadRange(e.threads);
        e.dtype = dtype;
        e.buffer = Downcast<Var>(buffer->data);
        e.buffer_ranges = buffer_ranges;
        for (const auto &index : load->indices) {
          e.touched.push_back(arith::IntSet::Vector(index));
        }
        e.is_pointer_access = true;
        e.type = kRead;
        e.scope = scope;
        curr_stmt_.access.emplace_back(e);
      }
      IRVisitorWithAnalyzer::VisitExpr_(load);
    } else {
      IRVisitorWithAnalyzer::VisitExpr_(op);
    }
  } else if (op->op.same_as(builtin::tvm_access_ptr())) {
    ICHECK_EQ(op->args.size(), 5U);
    DataType dtype = op->args[0].dtype();
    const VarNode *buffer_var = op->args[1].as<VarNode>();
    PrimExpr offset = op->args[2];
    PrimExpr extent = op->args[3];
    const IntImmNode *flag = op->args[4].as<IntImmNode>();
    StorageScope scope = GetScope(GetRef<Var>(buffer_var));
    if (Enabled(buffer_var, scope)) {
      ICHECK(allow_append_);
      Array<Range> buffer_ranges;
      if (buffer_data_to_buffer_.find(GetRef<Var>(buffer_var)) ==
          buffer_data_to_buffer_.end()) {
        // cannot find buffer map, use the default buffer
        buffer_ranges = {Range::FromMinExtent(offset, extent)};
      } else {
        Buffer buffer = buffer_data_to_buffer_.at(GetRef<Var>(buffer_var));
        auto buffer_shape = buffer->shape;
        // convert 1d offset to multi-dimensional index
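        // Row-major unflattening: e.g. with shape [4, 8], offset 26 maps to
        // indices [3, 2] because 26 == 3 * 8 + 2.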
        auto linear_to_indices = [this](PrimExpr offset,
                                        const Array<PrimExpr> &shape) {
          Array<PrimExpr> indices;
          PrimExpr remaining = std::move(offset);
          for (size_t i = 0; i < shape.size(); ++i) {
            PrimExpr stride = make_const(DataType::Int(32), 1);
            for (size_t j = i + 1; j < shape.size(); ++j) {
              stride = stride * shape[j];
            }
            PrimExpr idx = FloorDiv(remaining, stride);
            remaining = FloorMod(remaining, stride);
            indices.push_back(analyzer_.Simplify(idx));
          }
          return indices;
        };
        Array<PrimExpr> start_indices = linear_to_indices(offset, buffer_shape);
        Array<PrimExpr> end_indices =
            linear_to_indices(offset + extent, buffer_shape);
        for (size_t i = 0; i < buffer_shape.size(); ++i) {
          buffer_ranges.push_back(Range::FromMinExtent(
              start_indices[i],
              analyzer_.Simplify(end_indices[i] - start_indices[i])));
        }
      }
      AccessEntry e;
      e.threads = env_threads();
      e.thread_range = this->ComputeThreadRange(e.threads);
      e.dtype = dtype;
      e.buffer = GetRef<Var>(buffer_var);
      e.buffer_ranges = buffer_ranges;
      e.is_pointer_access = true;
      e.touched = {
          arith::IntSet::FromRange(Range::FromMinExtent(offset, extent))};
      e.scope = scope;
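      // tvm_access_ptr packs the access mask in its fifth argument:
      // bit 0 marks a read, bit 1 marks a write (both may be set).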
      if (flag->value & 1) {
        e.type = kRead;
        e.is_async_copy = (tma_depth_ > 0);
        curr_stmt_.access.emplace_back(e);
      }
      if (flag->value & 2) {
        e.type = kWrite;
        e.is_async_copy = (tma_depth_ > 0);
        curr_stmt_.access.emplace_back(e);
      }
    }
    IRVisitorWithAnalyzer::VisitExpr_(op);
  } else if (op->op.same_as(builtin::tvm_storage_sync())) {
    ICHECK(allow_append_);
    const std::string &s = op->args[0].as<StringImmNode>()->value;
    if (s != "warp") {
      StorageScope scope = StorageScope::Create(s);
      AccessEntry e;
      e.threads = env_threads();
      e.thread_range = this->ComputeThreadRange(e.threads);
      e.type = kSync;
      e.scope = scope;
      curr_stmt_.access.emplace_back(std::move(e));
    }
  } else {
    IRVisitorWithAnalyzer::VisitExpr_(op);
  }
}

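/*!
 * \brief Derive the value range of each threadIdx binding from the analyzer's
 * const-int bounds, e.g. threadIdx.x bound to [0, 127] yields a range with
 * min 0 and extent 128.
 */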
Map<Var, Range> TileLangStorageAccessVisitor::ComputeThreadRange(
    const Array<IterVar> &threads) {
  Map<Var, Range> thread_range;
  for (const auto &th : threads) {
    auto thread_tag = th->thread_tag;
    if (thread_tag == "threadIdx.x" || thread_tag == "threadIdx.y" ||
        thread_tag == "threadIdx.z") {
      auto const_int_bound = analyzer_.const_int_bound(th->var);
      auto min_value = const_int_bound->min_value;
      auto max_value = const_int_bound->max_value;
      auto extent = max_value - min_value + 1;
      auto dtype = th->var.dtype();
      thread_range.Set(th->var, Range::FromMinExtent(IntImm(dtype, min_value),
                                                     IntImm(dtype, extent)));
    }
  }
  return thread_range;
}

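/*!
 * \brief Look up the storage scope recorded in a buffer var's pointer type
 * annotation; vars without an annotation default to global scope.
 */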
StorageScope
TileLangStorageAccessVisitor::GetScope(const Var &buffer_var) const {
  if (buffer_var->type_annotation.as<PointerTypeNode>()) {
    return StorageScope::Create(GetPtrStorageScope(buffer_var));
  }
  return StorageScope(); // global by default
}

} // namespace tl
} // namespace tvm