/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file storage_access.cc
 */
#include "storage_access.h"

#include <tvm/arith/analyzer.h>
#include <tvm/target/target_info.h>
#include <tvm/tir/op.h>

#include <string>
#include <utility>

#include "../op/builtin.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

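// Record a read access for every buffer load in a tracked storage scope.
// Each index is captured as a vector IntSet so later passes can reason about
// which elements the load touches.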
void TileLangStorageAccessVisitor::VisitExpr_(const BufferLoadNode *op) {
  Var buf = op->buffer->data;
  buffer_data_to_buffer_.Set(buf, op->buffer);
  StorageScope scope = GetScope(buf);
  if (Enabled(buf.get(), scope)) {
    ICHECK(allow_append_) << tvm::ffi::GetRef<BufferLoad>(op) << " "
                          << scope.to_string();
    AccessEntry e;
    e.threads = env_threads();
    e.thread_range = this->ComputeThreadRange(e.threads);
    e.buffer = buf;
    e.buffer_indices = op->indices;
    e.dtype = op->dtype.element_of();
    for (const auto &index : op->indices) {
      e.touched.push_back(arith::IntSet::Vector(index));
    }
    e.type = kRead;
    e.scope = scope;
    curr_stmt_.access.emplace_back(std::move(e));
  }
  // traverse child
  IRVisitorWithAnalyzer::VisitExpr_(op);
}

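// A buffer store opens a fresh statement entry: record the write access,
// visit the stored value and indices (which may append further reads), then
// flush the collected accesses into the innermost scope.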
void TileLangStorageAccessVisitor::VisitStmt_(const BufferStoreNode *op) {
  allow_append_ = true;
  ICHECK_EQ(curr_stmt_.access.size(), 0U);
  curr_stmt_.stmt = op;

  Var buf = op->buffer->data;
  buffer_data_to_buffer_.Set(buf, op->buffer);
  StorageScope scope = GetScope(buf);
  if (Enabled(buf.get(), scope)) {
    AccessEntry e;
    e.threads = env_threads();
    e.thread_range = this->ComputeThreadRange(e.threads);
    e.buffer = buf;
    e.buffer_indices = op->indices;
    e.dtype = op->value.dtype().element_of();
    for (const auto &index : op->indices) {
      e.touched.push_back(arith::IntSet::Vector(index));
    }
    e.type = kWrite;
    e.scope = scope;
    curr_stmt_.access.emplace_back(std::move(e));
  }
  // traverse child
  IRVisitorWithAnalyzer::VisitStmt_(op);
  // push to the scope
  scope_.back().push_back(curr_stmt_);
  // clear access entry.
  curr_stmt_.access.clear();
  allow_append_ = false;
}

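// Evaluate nodes wrap call expressions (e.g. tvm_access_ptr or
// tvm_storage_sync) whose visitors may append access entries; collect them
// as a single statement entry.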
void TileLangStorageAccessVisitor::VisitStmt_(const EvaluateNode *op) {
  allow_append_ = true;
  ICHECK_EQ(curr_stmt_.access.size(), 0U);
  curr_stmt_.stmt = op;
  IRVisitorWithAnalyzer::VisitStmt_(op);
  // push to the scope
  if (!curr_stmt_.access.empty()) {
    scope_.back().push_back(curr_stmt_);
    curr_stmt_.access.clear();
  }
  allow_append_ = false;
}

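// For a let-binding, only the bound value may contribute accesses to this
// entry; the body is traversed afterwards as an ordinary statement.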
void TileLangStorageAccessVisitor::VisitStmt_(const LetStmtNode *op) {
  allow_append_ = true;
  ICHECK_EQ(curr_stmt_.access.size(), 0U);
  curr_stmt_.stmt = op;
  this->VisitExpr(op->value);
  // push to the scope
  scope_.back().push_back(curr_stmt_);
  // clear access entry.
  curr_stmt_.access.clear();
  allow_append_ = false;
  // traverse body block
  this->VisitStmt(op->body);
}

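// Register block-allocated buffers so that flat tvm_access_ptr offsets can
// later be mapped back to multi-dimensional ranges.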
void TileLangStorageAccessVisitor::VisitStmt_(const BlockNode *op) {
  auto block = Downcast<Block>(op);
  for (const auto &buffer : block->alloc_buffers) {
    ICHECK(buffer->IsInstance<BufferNode>());
    buffer_data_to_buffer_.Set(buffer->data, buffer);
  }
  IRVisitorWithAnalyzer::VisitStmt_(op);
}

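// Dispatch on the attribute kind: double-buffer writes are tagged, coproc
// and thread-extent attributes push thread context, and hand-threaded
// regions are skipped entirely.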
void TileLangStorageAccessVisitor::VisitStmt_(const AttrStmtNode *op) {
  if (op->attr_key == tvm::tir::attr::double_buffer_write) {
    ICHECK(double_buffer_write_ == nullptr);
    double_buffer_write_ = op->node.as<VarNode>();
    scope_.push_back(std::vector<StmtEntry>());
    IRVisitorWithAnalyzer::VisitStmt_(op);
    StmtEntry s;
    s.stmt = op;
    s.access = Summarize(std::move(scope_.back()), nullptr);
    scope_.pop_back();
    if (!s.access.empty()) {
      for (AccessEntry &e : s.access) {
        if (e.type == kWrite && e.buffer.get() == double_buffer_write_) {
          e.double_buffer_write = true;
        }
      }
      scope_.back().emplace_back(std::move(s));
    }
    double_buffer_write_ = nullptr;
  } else if (op->attr_key == tvm::tir::attr::coproc_scope) {
    IterVar iv = Downcast<IterVar>(op->node);
    env_threads_.push_back(iv);
    IRVisitorWithAnalyzer::VisitStmt_(op);
    env_threads_.pop_back();
  } else if (op->attr_key == tvm::tir::attr::thread_extent) {
    IterVar iv = Downcast<IterVar>(op->node);
    env_threads_.push_back(iv);
    ICHECK_NE(iv->thread_tag.length(), 0U);
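    // Bind the launch-thread variable to [0, extent) so the arithmetic
    // analyzer can bound thread indices (used by ComputeThreadRange).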
    analyzer_.Bind(
        iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0), op->value));

    if (!in_device_env_) {
      in_device_env_ = true;
      scope_.push_back(std::vector<StmtEntry>());
      IRVisitorWithAnalyzer::VisitStmt_(op);
      // no need to take the result as the thread barrier automatically syncs.
      Summarize(std::move(scope_.back()), nullptr);
      in_device_env_ = false;
      scope_.pop_back();
    } else {
      IRVisitorWithAnalyzer::VisitStmt_(op);
    }
    env_threads_.pop_back();
  } else if (op->attr_key == tvm::tir::attr::hand_threaded) {
    // skip this pass on blocks that were hand_threaded
    // this avoids control flow and read/write conflicts
    // between hand-threaded kernels and automatic threading
  } else {
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }
}

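// Each loop body is summarized in its own scope; afterwards the touched
// index sets are relaxed over the loop variable's full range.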
void TileLangStorageAccessVisitor::VisitStmt_(const ForNode *op) {
  scope_.push_back(std::vector<StmtEntry>());
  IRVisitorWithAnalyzer::VisitStmt_(op);
  StmtEntry s;
  s.stmt = op;
  s.access = Summarize(std::move(scope_.back()), op);
  scope_.pop_back();
  if (!s.access.empty()) {
    // relax the touched set to contain all ranges in the loop.
    std::unordered_map<const VarNode *, arith::IntSet> relax_map;
    relax_map[op->loop_var.get()] =
        arith::IntSet::FromRange(Range::FromMinExtent(op->min, op->extent));
    for (AccessEntry &e : s.access) {
      if (e.buffer.defined()) {
        ICHECK(!e.touched.empty());
        Array<arith::IntSet> new_touched;
        for (const auto &touched : e.touched) {
          new_touched.push_back(arith::EvalSet(touched, relax_map));
        }
        e.touched = std::move(new_touched);
      }
    }
  }
  if (!s.access.empty()) {
    scope_.back().emplace_back(std::move(s));
  }
}

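// A condition wrapped in tvm_thread_invariant() is guaranteed to evaluate
// identically across all threads, so it cannot introduce divergence.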
bool IsThreadInvariant(const PrimExpr &cond) {
  if (auto call = cond.as<CallNode>()) {
    if (auto opt_call_op = call->op.as<Op>()) {
      const auto &call_op = opt_call_op.value();
      if (call_op.same_as(builtin::tvm_thread_invariant())) {
        return true;
      }
    }
  }
  return false;
}

/**
 * @brief Visit an IfThenElse statement and collect storage access summaries
 * for its branches.
 *
 * Visits the condition and both branches to summarize buffer reads, writes,
 * and synchronization events under the condition's constraints. If the
 * condition is not thread-invariant, an internal condition counter is
 * incremented for the duration of processing.
 *
 * Behavior and side effects:
 * - Evaluates the condition expression (via ExtractRealCondition) and
 *   applies it as a constraint while summarizing the then-branch.
 * - For the else-branch (when present), applies the negated,
 *   analyzer-simplified condition as the constraint.
 * - Appends a combined StmtEntry for the IfThenElseNode, holding the
 *   summarized accesses of both branches, to the current scope.
 * - Temporarily toggles allow_append_ and clears curr_stmt_.access during
 *   condition evaluation and branch summarization.
 * - Modifies internal state: scope_ (push/pop of temporary branch scopes),
 *   curr_stmt_.access, and condition_counter_ (incremented/decremented when
 *   the condition is not thread-invariant).
 */
void TileLangStorageAccessVisitor::VisitStmt_(const IfThenElseNode *op) {
  bool is_thread_invariant = IsThreadInvariant(op->condition);
  if (!is_thread_invariant) {
    ++condition_counter_;
  }

  allow_append_ = true;
  this->VisitExpr(op->condition);
  PrimExpr real_condition = ExtractRealCondition(op->condition);

  curr_stmt_.access.clear();
  allow_append_ = false;

  scope_.push_back(std::vector<StmtEntry>());
  {
    With<arith::ConstraintContext> constraint(&analyzer_, real_condition);
    this->VisitStmt(op->then_case);
  }

  StmtEntry s;
  s.stmt = op;
  s.access = Summarize(std::move(scope_.back()), nullptr);
  scope_.pop_back();
  if (op->else_case) {
    scope_.push_back(std::vector<StmtEntry>());
    {
      With<arith::ConstraintContext> constraint(
          &analyzer_, analyzer_.rewrite_simplify(Not(real_condition)));
      this->VisitStmt(op->else_case.value());
    }
    auto v = Summarize(std::move(scope_.back()), nullptr);
    scope_.pop_back();
    s.access.insert(s.access.end(), v.begin(), v.end());
  }
  scope_.back().emplace_back(std::move(s));
  if (!is_thread_invariant) {
    --condition_counter_;
  }
}

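// While loops are handled like conditionals, but without a constraint
// context: the body's accesses are summarized into a single entry.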
void TileLangStorageAccessVisitor::VisitStmt_(const WhileNode *op) {
  bool is_thread_invariant = IsThreadInvariant(op->condition);
  if (!is_thread_invariant) {
    ++condition_counter_;
  }
  this->VisitExpr(op->condition);
  scope_.push_back(std::vector<StmtEntry>());
  this->VisitStmt(op->body);
  StmtEntry s;
  s.stmt = op;
  s.access = Summarize(std::move(scope_.back()), nullptr);
  scope_.pop_back();
  scope_.back().emplace_back(std::move(s));
  if (!is_thread_invariant) {
    --condition_counter_;
  }
}

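// Intrinsic calls carry implicit accesses: address_of and tvm_access_ptr
// expose pointer reads/writes, tvm_storage_sync records a barrier, and TMA
// loads mark the accesses made through their arguments as async copies.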
void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
  // Mark async TMA load context so that tvm_access_ptr within the call
  // can be tagged accordingly.
  auto is_tma_load = [&]() {
    if (auto opt = op->op.as<Op>()) {
      const Op &call_op = opt.value();
      return call_op.same_as(tl::tma_load()) ||
             call_op.same_as(tl::tma_load_im2col());
    }
    return false;
  }();
  if (is_tma_load) {
    tma_depth_++;
    for (const auto &a : op->args) {
      this->VisitExpr(a);
    }
    tma_depth_--;
    return;
  }
  if (op->op.same_as(builtin::address_of())) {
    ICHECK_EQ(op->args.size(), 1U);
    if (auto load = op->args[0].as<BufferLoadNode>()) {
      Buffer buffer = load->buffer;
      DataType dtype = buffer->dtype;
      const VarNode *buffer_var = buffer->data.as<VarNode>();
      buffer_data_to_buffer_.Set(tvm::ffi::GetRef<Var>(buffer_var), buffer);
      StorageScope scope = GetScope(tvm::ffi::GetRef<Var>(buffer_var));
      Array<Range> buffer_ranges;
      // from indices to buffer indices
      ICHECK_EQ(buffer->shape.size(), load->indices.size());
      // Use buffer shape and indices to compute the buffer_ranges for each
      // dimension.
      for (size_t i = 0; i < buffer->shape.size(); ++i) {
        PrimExpr min = load->indices[i];
        PrimExpr extent = make_const(buffer->shape[i].dtype(), 1);
        buffer_ranges.push_back(Range::FromMinExtent(min, extent));
      }
      if (Enabled(buffer_var, scope)) {
        ICHECK(allow_append_);
        AccessEntry e;
        e.threads = env_threads();
        e.thread_range = this->ComputeThreadRange(e.threads);
        e.dtype = dtype;
        e.buffer = buffer->data;
        e.buffer_ranges = buffer_ranges;
        for (const auto &index : load->indices) {
          e.touched.push_back(arith::IntSet::Vector(index));
        }
        e.is_pointer_access = true;
        e.type = kRead;
        e.scope = scope;
        curr_stmt_.access.emplace_back(e);
      }
      IRVisitorWithAnalyzer::VisitExpr_(load);
    } else {
      IRVisitorWithAnalyzer::VisitExpr_(op);
    }
  } else if (op->op.same_as(builtin::tvm_access_ptr())) {
    ICHECK_EQ(op->args.size(), 5U);
    DataType dtype = op->args[0].dtype();
    const VarNode *buffer_var = op->args[1].as<VarNode>();
    PrimExpr offset = op->args[2];
    PrimExpr extent = op->args[3];
    const IntImmNode *flag = op->args[4].as<IntImmNode>();
    // The buffer scope.
    StorageScope scope = GetScope(tvm::ffi::GetRef<Var>(buffer_var));
    if (Enabled(buffer_var, scope)) {
      ICHECK(allow_append_);
      Array<Range> buffer_ranges;
      if (buffer_data_to_buffer_.find(tvm::ffi::GetRef<Var>(buffer_var)) ==
          buffer_data_to_buffer_.end()) {
        // buffer not found in the map; fall back to a flat 1-D range
        buffer_ranges = {Range::FromMinExtent(offset, extent)};
      } else {
        Buffer buffer =
            buffer_data_to_buffer_.at(tvm::ffi::GetRef<Var>(buffer_var));
        auto buffer_shape = buffer->shape;
        // convert 1d offset to multi-dimensional index
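        // e.g. with shape [4, 8] (row-major strides [8, 1]), offset 10
        // delinearizes to indices [1, 2].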
        auto linear_to_indices = [this](PrimExpr offset,
                                        const Array<PrimExpr> &shape) {
          Array<PrimExpr> indices;
          PrimExpr remaining = std::move(offset);
          for (size_t i = 0; i < shape.size(); ++i) {
            PrimExpr stride = make_const(DataType::Int(32), 1);
            for (size_t j = i + 1; j < shape.size(); ++j) {
              stride = stride * shape[j];
            }
            PrimExpr idx = FloorDiv(remaining, stride);
            remaining = FloorMod(remaining, stride);
            indices.push_back(analyzer_.Simplify(idx));
          }
          return indices;
        };
        Array<PrimExpr> start_indices = linear_to_indices(offset, buffer_shape);
        Array<PrimExpr> end_indices =
            linear_to_indices(offset + extent, buffer_shape);
        for (size_t i = 0; i < buffer_shape.size(); ++i) {
          buffer_ranges.push_back(Range::FromMinExtent(
              start_indices[i],
              analyzer_.Simplify(end_indices[i] - start_indices[i])));
        }
      }
      AccessEntry e;
      e.threads = env_threads();
      e.thread_range = this->ComputeThreadRange(e.threads);
      e.dtype = dtype;
      e.buffer = tvm::ffi::GetRef<Var>(buffer_var);
      e.buffer_ranges = buffer_ranges;
      e.is_pointer_access = true;
      e.touched = {
          arith::IntSet::FromRange(Range::FromMinExtent(offset, extent))};
      e.scope = scope;
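      // rw_mask of tvm_access_ptr: bit 0 marks a read, bit 1 a write; both
      // bits may be set for a read-modify-write access.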
      if (flag->value & 1) {
        e.type = kRead;
        e.is_async_copy = (tma_depth_ > 0);
        curr_stmt_.access.emplace_back(e);
      }
      if (flag->value & 2) {
        e.type = kWrite;
        e.is_async_copy = (tma_depth_ > 0);
        curr_stmt_.access.emplace_back(e);
      }
    }
    IRVisitorWithAnalyzer::VisitExpr_(op);
  } else if (op->op.same_as(builtin::tvm_storage_sync())) {
    ICHECK(allow_append_);
    const std::string &s = op->args[0].as<StringImmNode>()->value;
    if (s != "warp") {
      StorageScope scope = StorageScope::Create(s);
      AccessEntry e;
      e.threads = env_threads();
      e.thread_range = this->ComputeThreadRange(e.threads);
      e.type = kSync;
      e.scope = scope;
      curr_stmt_.access.emplace_back(std::move(e));
    }
  } else {
    IRVisitorWithAnalyzer::VisitExpr_(op);
  }
}

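// Derive a conservative [min, min + extent) range for each threadIdx.*
// variable from the analyzer's const-int bounds; other thread tags are left
// out of the map.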
Map<Var, Range> TileLangStorageAccessVisitor::ComputeThreadRange(
    const Array<IterVar> &threads) {
  Map<Var, Range> thread_range;
  for (const auto &th : threads) {
    auto thread_tag = th->thread_tag;
    if (thread_tag == "threadIdx.x" || thread_tag == "threadIdx.y" ||
        thread_tag == "threadIdx.z") {
      auto const_int_bound = analyzer_.const_int_bound(th->var);
      auto min_value = const_int_bound->min_value;
      auto max_value = const_int_bound->max_value;
      auto extent = max_value - min_value + 1;
      auto dtype = th->var.dtype();
      thread_range.Set(th->var, Range::FromMinExtent(IntImm(dtype, min_value),
                                                     IntImm(dtype, extent)));
    }
  }
  return thread_range;
}

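// The storage scope is read from the pointer type annotation; variables
// without one default to the global scope.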
StorageScope
TileLangStorageAccessVisitor::GetScope(const Var &buffer_var) const {
  if (buffer_var->type_annotation.as<PointerTypeNode>()) {
    return StorageScope::Create(GetPtrStorageScope(buffer_var));
  }
  return StorageScope(); // global by default
}

} // namespace tl
} // namespace tvm