inject_tma_barrier.cc 21.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file inject_tma_barrier.cc
 * \brief Inject and rewrite TMA mbarrier operations for CUDA GPUs (sm90+).
 */

#include <tvm/arith/analyzer.h>
26
#include <tvm/ffi/reflection/registry.h>
27
28
29
30
31
32
33
34
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

35
36
#include <utility>

37
#include "../op/builtin.h"
38
39
#include "./common/attr.h"
#include "./common/collector.h"
40
41
42
43
44
45
46
47
48
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"

namespace tvm {
namespace tl {

using namespace tir;
using namespace tir::transform;
using arith::IRMutatorWithAnalyzer;
49
using arith::IRVisitorWithAnalyzer;
50
51
52
53
54
55
56
57
58
59

class TmaTraitsCollector : public StmtExprVisitor {
public:
  TmaTraitsCollector() { Initialize(); }

  void Initialize() {
    bulk_copy_bytes = 0;
    loop_extents = 1;
  }

60
  void Collect(const Stmt &stmt) { VisitStmt(stmt); }
61
62
63
64
65
66

  PrimExpr BulkCopyBytes() { return bulk_copy_bytes; }

private:
  void VisitExpr_(const CallNode *call) final {
    if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
67
68
69
70
71
72
73
74
75
76
77
      auto arg0 = call->args[0].as<Call>();
      if (call->op.same_as(tma_load()) && arg0 &&
          !arg0.value()->op.same_as(create_tma_descriptor())) {
        // 1D TMA load has tvm_access_ptr of shared tensor in its args[0]
        bulk_copy_bytes = call->args[3] * loop_extents;
      } else {
        Call access_ptr = Downcast<Call>(call->args[2]);
        ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
        int type_bytes = access_ptr->args[0]->dtype.bytes();
        bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes;
      }
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
    }
    StmtExprVisitor::VisitExpr_(call);
  }

  void VisitStmt_(const ForNode *op) final {
    PrimExpr old_loop_evtents = loop_extents;
    loop_extents *= op->extent;
    StmtExprVisitor::VisitStmt_(op);
    loop_extents = old_loop_evtents;
  }

  PrimExpr bulk_copy_bytes = 0;
  PrimExpr loop_extents = 1;
};

class TmaExpectTxRewriter : public IRMutatorWithAnalyzer {
public:
  static PrimFunc Rewrite(PrimFunc f, arith::Analyzer *analyzer) {
    TmaExpectTxRewriter rewriter(analyzer);
    f.CopyOnWrite()->body = rewriter(f->body);
    return f;
  }

private:
  bool inside_tma_block_{false};
  bool visited_tma_load_{false};
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);

  PrimExpr makeGetBarrier(PrimExpr barrier_id) {
108
    return Call(DataType::Handle(), get_mbarrier(), {std::move(barrier_id)});
109
110
111
112
  }

  Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) {
    auto call = Call(DataType::Handle(), mbarrier_expect_tx(),
113
                     {makeGetBarrier(std::move(barrier_id)), std::move(bytes)});
114
115
116
117
118
119
120
    return Evaluate(call);
  }

  TmaExpectTxRewriter(arith::Analyzer *analyzer)
      : IRMutatorWithAnalyzer(analyzer) {}

  Stmt VisitStmt_(const AttrStmtNode *op) final {
121

122
123
124
125
126
127
128
129
130
131
132
133
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        ICHECK(iv->dom->extent.as<IntImmNode>());
        thread_var_ = iv;
      }
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

  Stmt VisitStmt_(const IfThenElseNode *op) {
    // Check if this is the TMA block
134
135
136
137
138
    bool flag = false;
    if (op->condition.as<CallNode>()) {
      flag = op->condition.as<CallNode>()->op.same_as(tl_shuffle_elect());
    }
    if (op->condition.as<EQNode>() || flag) {
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
      Stmt ret = IRMutatorWithAnalyzer::VisitStmt_(op);

      if (visited_tma_load_) {
        auto then_case = op->then_case;
        TmaTraitsCollector collector;
        collector.Collect(then_case);

        Array<Stmt> stmts;
        if (!is_zero(collector.BulkCopyBytes())) {
          auto expect_tx = makeExpectTX(0, collector.BulkCopyBytes());
          stmts.push_back(expect_tx);
        }
        stmts.push_back(then_case);
        if (stmts.size() == 1) {
          return IfThenElse(op->condition, stmts[0], op->else_case);
        } else {
          auto seq_stmt = SeqStmt(stmts);
          return IfThenElse(op->condition, seq_stmt, op->else_case);
        }
      }
      visited_tma_load_ = false;
      return ret;
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

  PrimExpr VisitExpr_(const CallNode *op) {
166
    if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
167
168
169
170
      auto arg0 = op->args[0].as<Call>();
      bool is_1d_tma_load =
          arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
          op->op.same_as(tma_load());
171
172
      visited_tma_load_ = true;
      Array<PrimExpr> new_args = op->args;
173
174
175
      new_args.Set(is_1d_tma_load ? 2 : 1,
                   Call(DataType::Handle(), get_mbarrier(),
                        {IntImm(DataType::Int(32), 0)}));
176
177
178
179
180
181
      return Call(op->dtype, op->op, new_args);
    }
    return IRMutatorWithAnalyzer::VisitExpr_(op);
  }
};

182
class TmaBarrierCollector : public IRVisitorWithAnalyzer {
183
public:
184
185
186
  TmaBarrierCollector(Map<Var, Buffer> buffer_data_to_buffer)
      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)) {}

187
188
189
  Map<ObjectRef, PrimExpr> tma_op_to_barrier_id() {
    return tma_op_to_barrier_id_;
  }
190
  Map<PrimExpr, IntImm> barrier_id_to_range() { return barrier_id_to_range_; }
191
192

private:
193
  void UpdateBarrierRange(const PrimExpr &barrier_id, const IntImm &extent) {
194
195
196
197
198
199
200
201
202
203
    if (barrier_id_to_range_.count(barrier_id)) {
      auto old_extent = barrier_id_to_range_[barrier_id];
      ICHECK_EQ(old_extent->value, extent->value)
          << "barrier_id: " << barrier_id << " has different extent";
      barrier_id_to_range_.Set(barrier_id, extent);
    } else {
      barrier_id_to_range_.Set(barrier_id, extent);
    }
  }

204
205
  void VisitStmt_(const EvaluateNode *op) final {
    if (const auto *call = op->value.as<CallNode>()) {
206
      if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
207
        pending_tma_ops_.push_back(tvm::ffi::GetRef<Call>(call));
208
      } else if (call->op.same_as(mbarrier_expect_tx())) {
209
        pending_tma_ops_.push_back(tvm::ffi::GetRef<Call>(call));
210
211
      } else if (call->op.same_as(builtin::ptx_arrive_barrier())) {
        PrimExpr barrier_id = call->args[0];
212
        for (const auto &tma_call : pending_tma_ops_) {
213
214
          tma_op_to_barrier_id_.Set(tma_call, barrier_id);
        }
215
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
216
217
        auto extent =
            const_int_bound->max_value - const_int_bound->min_value + 1;
218
        UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent));
219
        pending_tma_ops_.clear();
220
221
222
      } else if (call->op.same_as(builtin::ptx_wait_barrier())) {
        PrimExpr barrier_id = call->args[0];
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
223
224
        auto extent =
            const_int_bound->max_value - const_int_bound->min_value + 1;
225
        UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent));
226
227
228
229
230
      }
    }
    StmtExprVisitor::VisitStmt_(op);
  }

231
232
233
234
235
236
237
238
239
240
241
  void VisitStmt_(const AttrStmtNode *op) {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        thread_var_ = iv;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  IterVar thread_var_;
242
243
  std::vector<Call> pending_tma_ops_;
  Map<ObjectRef, PrimExpr> tma_op_to_barrier_id_;
244
  Map<PrimExpr, IntImm> barrier_id_to_range_;
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
  Map<Var, Buffer> buffer_data_to_buffer_;
};

class TmaSequenceCollector : public IRVisitorWithAnalyzer {
public:
  TmaSequenceCollector(Map<ObjectRef, PrimExpr> tma_op_to_barrier_id)
      : tma_op_to_barrier_id_(std::move(tma_op_to_barrier_id)) {}

  std::vector<bool> GetSequence() {
    std::vector<bool> clear_zero_list(expect_tx_count_, false);
    int zero_idx = -1;
    int zero_count = 0;

    for (auto v : sequence) {
      if (v == 0) {
        zero_count += 1;
        zero_idx += 1;
      } else {
        if (zero_count == 1) {
          clear_zero_list[zero_idx] = expect_[zero_idx] && !has_simt_copy_;
          if (clear_zero_list[zero_idx] == false) {
            int begin = int_sets_[zero_idx].min().as<IntImmNode>()->value;
            int end = int_sets_[zero_idx].max().as<IntImmNode>()->value;
            for (int i = begin; i <= end; ++i) {
              restore_barrier_ids_.push_back(i);
            }
          }
        } else {
          for (int i{zero_idx}; i > zero_idx - zero_count; --i) {
            int begin = int_sets_[i].min().as<IntImmNode>()->value;
            int end = int_sets_[i].max().as<IntImmNode>()->value;
            for (int i = begin; i <= end; ++i) {
              restore_barrier_ids_.push_back(i);
            }
          }
        }
        zero_count = 0;
      }
    }

    return clear_zero_list;
  }

  std::vector<int> GetRestoreBarrierIds() { return restore_barrier_ids_; }

  void VisitStmt_(const ForNode *op) final {
    var_int_set_.Set(op->loop_var,
                     arith::IntSet::FromMinExtent(op->min, op->extent));
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  void VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(mbarrier_expect_tx())) {
298
299
300
301
302
303
304
305
306
      auto call_ref = tvm::ffi::GetRef<Call>(op);
      if (tma_op_to_barrier_id_.count(call_ref)) {
        PrimExpr e = tma_op_to_barrier_id_[call_ref].as<CallNode>()->args[0];
        auto int_set = arith::EvalSet(e, var_int_set_);
        expect_.push_back(if_depth_ == 1);
        sequence.push_back(0);
        int_sets_.push_back(int_set);
        expect_tx_count_ += 1;
      }
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
    } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
      sequence.push_back(1);
    } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
      has_simt_copy_ = true;
    }
    IRVisitorWithAnalyzer::VisitExpr_(op);
  }

  void VisitStmt_(const IfThenElseNode *op) final {
    if_depth_ += 1;

    IRVisitorWithAnalyzer::VisitStmt(op->then_case);

    if (op->else_case) {
      IRVisitorWithAnalyzer::VisitStmt(op->else_case.value());
    }
    if_depth_ -= 1;
  }

  std::vector<int> sequence;
  int expect_tx_count_{0};
  std::vector<bool> expect_;
  bool has_simt_copy_{false};
  std::vector<int> restore_barrier_ids_;
  int if_depth_{0};
  Map<ObjectRef, PrimExpr> tma_op_to_barrier_id_;
333
  arith::Analyzer *analyzer_{};
334
335
  Map<Var, arith::IntSet> var_int_set_;
  std::vector<arith::IntSet> int_sets_;
336
};
337
338
339
340

class BarrierCreationRewriter : public StmtExprMutator {
public:
  BarrierCreationRewriter(std::vector<int> restore_barrier_ids,
341
342
343
                          PrimExpr producer_thread_extent,
                          int ensure_min_count = 0,
                          PrimExpr default_barrier_thread_count = 1)
344
      : restore_barrier_ids_(std::move(restore_barrier_ids)),
345
346
347
348
        producer_thread_extent_(std::move(producer_thread_extent)),
        ensure_min_count_(ensure_min_count),
        default_barrier_thread_count_(std::move(default_barrier_thread_count)) {
  }
349
350
351

  PrimExpr VisitExpr_(const CallNode *op) {
    if (op->op.same_as(create_list_of_mbarrier())) {
352
353
354
355
356
357
358
      size_t cur_n = op->args.size();
      size_t need_n =
          std::max<size_t>(cur_n, static_cast<size_t>(ensure_min_count_));

      // Mark barriers to restore across the full needed length, not just the
      // original length, so newly appended entries can be restored as well.
      std::vector<bool> replace(need_n, false);
359
      for (auto &id : restore_barrier_ids_) {
360
361
362
        if (id >= 0 && static_cast<size_t>(id) < replace.size()) {
          replace[id] = true;
        }
363
364
      }

365
366
367
368
369
370
      Array<PrimExpr> new_args;
      new_args.reserve(need_n);

      // Preserve/override existing entries
      for (size_t i{0}; i < cur_n; ++i) {
        if (replace[i]) {
371
372
373
374
375
          new_args.push_back(producer_thread_extent_);
        } else {
          new_args.push_back(op->args[i]);
        }
      }
376
377
378
379
380
381
382
383
384
      // Append additional barriers if required
      for (size_t i = cur_n; i < need_n; ++i) {
        if (replace[i]) {
          new_args.push_back(producer_thread_extent_);
        } else {
          new_args.push_back(default_barrier_thread_count_);
        }
      }

385
386
387
388
389
      return Call(op->dtype, op->op, new_args);
    } else {
      return StmtExprMutator::VisitExpr_(op);
    }
  }
390
391

private:
392
393
  std::vector<int> restore_barrier_ids_;
  PrimExpr producer_thread_extent_;
394
395
  int ensure_min_count_{0};
  PrimExpr default_barrier_thread_count_{1};
396
397
};

398
399
400
401
// we trust mbarrier_wait_parity to be correct
class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
public:
  TmaBarrierRewriter(arith::Analyzer *analyzer,
402
403
404
                     Map<ObjectRef, PrimExpr> tma_op_to_barrier_id,
                     Map<PrimExpr, IntImm> barrier_id_to_range,
                     bool has_create_list_of_mbarrier)
405
      : IRMutatorWithAnalyzer(analyzer),
406
407
        tma_op_to_barrier_id_(std::move(tma_op_to_barrier_id)),
        barrier_id_to_range_(std::move(barrier_id_to_range)),
408
        has_create_list_of_mbarrier_(has_create_list_of_mbarrier) {}
409
410

  static PrimFunc Rewrite(PrimFunc f, arith::Analyzer *analyzer) {
411
412
413
414
    auto buffer_lca = DetectBufferAccessLCA(f);
    Map<Var, Buffer> buffer_data_to_buffer_;
    for (auto [buffer, _] : buffer_lca)
      buffer_data_to_buffer_.Set(buffer->data, buffer);
415
    f = TmaExpectTxRewriter::Rewrite(f, analyzer);
416
    TmaBarrierCollector collector(buffer_data_to_buffer_);
417
    collector(f->body);
418
    bool has_create_list_of_mbarrier = false;
419
420
    PostOrderVisit(f->body, [&](const ObjectRef &node) {
      if (const auto *call = node.as<CallNode>()) {
421
422
        if (call->op.same_as(create_list_of_mbarrier())) {
          has_create_list_of_mbarrier = true;
423
424
        } else if (call->op.same_as(builtin::ptx_init_barrier_thread_count())) {
          has_create_list_of_mbarrier = true;
425
426
427
428
        }
      }
    });
    TmaBarrierRewriter rewriter(analyzer, collector.tma_op_to_barrier_id(),
429
430
                                collector.barrier_id_to_range(),
                                has_create_list_of_mbarrier);
431
    f.CopyOnWrite()->body = rewriter(f->body);
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
    // Compute the minimum number of barriers actually referenced in the body
    // after TMA barrier rewrites (e.g., get_mbarrier(0) inserted for TMA).
    struct GetMbarrierMaxIdxCollector : public StmtExprVisitor {
      int max_idx{-1};
      void VisitExpr_(const CallNode *op) final {
        if (op->op.same_as(get_mbarrier())) {
          if (op->args.size() == 1) {
            if (const auto *imm = op->args[0].as<IntImmNode>()) {
              max_idx = std::max(max_idx, static_cast<int>(imm->value));
            }
          }
        }
        StmtExprVisitor::VisitExpr_(op);
      }
    };

    GetMbarrierMaxIdxCollector max_idx_collector;
    max_idx_collector(f->body);
    int ensure_min_count = max_idx_collector.max_idx + 1; // 0-based -> count

    // For simple TMA-only producers, default barrier arrive count should be 1
    // (only the elected leader performs the TMA arrive/expect).
454
    auto barrier_creation_rewriter = BarrierCreationRewriter(
455
456
        rewriter.restore_barrier_ids_, rewriter.producer_thread_extent_,
        ensure_min_count, Integer(1));
457
    f.CopyOnWrite()->body = barrier_creation_rewriter(f->body);
458
459
460
461
    return f;
  }

private:
462
  Stmt VisitStmt_(const BlockNode *op) {
463
    auto block = tvm::ffi::GetRef<Block>(op);
464
    if (!has_create_list_of_mbarrier_ && !barrier_id_to_range_.empty() &&
465
        op->name_hint == MainBlockName) {
466
467
468
469
470
      ICHECK(false) << "Please declare create_list_of_mbarrier.";
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
  Stmt VisitStmt_(const IfThenElseNode *op) {
    if (first_if) {
      if (op->condition.as<GENode>()) {
        producer_thread_extent_ =
            thread_var_->dom->extent - op->condition.as<GENode>()->b;
      }
      TmaSequenceCollector collector(tma_op_to_barrier_id_);
      collector(op->then_case);
      clear_expect_list_ = collector.GetSequence();
      restore_barrier_ids_ = collector.GetRestoreBarrierIds();
      first_if = false;

      is_producer_ = true;

      auto then_case = StmtExprMutator::VisitStmt(op->then_case);

      is_producer_ = false;
      Stmt else_case;
      if (op->else_case.defined())
        else_case = StmtExprMutator::VisitStmt(op->else_case.value());
      return IfThenElse(op->condition, then_case, else_case);
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == "kWarpSpecializationScope") {
      has_warp_specialization_ = true;
      first_if = true;
    } else if (op->attr_key == tir::attr::thread_extent &&
               Downcast<IterVar>(op->node)->thread_tag == "threadIdx.x") {
      thread_var_ = Downcast<IterVar>(op->node);
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

507
  PrimExpr VisitExpr_(const CallNode *op) {
508
    if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
      auto call_ref = tvm::ffi::GetRef<Call>(op);
      if (!tma_op_to_barrier_id_.count(call_ref)) {
        // For 1D TMA loads, promote raw integer barrier id to get_mbarrier(id)
        // so codegen can emit mbarrier[index]. This handles degenerate
        // producer-only kernels where no arrive() is seen and mapping is empty.
        auto arg0 = op->args[0].as<Call>();
        bool is_1d_tma_load =
            arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
            !arg0.value()->op.same_as(create_tma_im2col_descriptor());
        if (is_1d_tma_load && op->args.size() >= 3) {
          if (const auto *imm = op->args[2].as<IntImmNode>()) {
            Array<PrimExpr> new_args = op->args;
            new_args.Set(2, Call(DataType::Handle(), get_mbarrier(),
                                 {IntImm(DataType::Int(32),
                                         static_cast<int>(imm->value))}));
            return Call(op->dtype, op->op, new_args);
          }
        }
        return IRMutatorWithAnalyzer::VisitExpr_(op);
      }
      auto barrier_id = tma_op_to_barrier_id_[call_ref];
530
      auto new_args = op->args;
531
532
      auto arg0 = op->args[0].as<Call>();
      auto is_1d_tma_load =
533
534
          arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
          !arg0.value()->op.same_as(create_tma_im2col_descriptor());
535
536
537
538
539
      if (is_1d_tma_load) {
        new_args.Set(2, barrier_id);
      } else {
        new_args.Set(1, barrier_id);
      }
540
541
      return Call(op->dtype, op->op, new_args);
    } else if (op->op.same_as(mbarrier_expect_tx())) {
542
543
544
545
546
      auto call_ref = tvm::ffi::GetRef<Call>(op);
      if (!tma_op_to_barrier_id_.count(call_ref)) {
        return IRMutatorWithAnalyzer::VisitExpr_(op);
      }
      auto barrier_id = tma_op_to_barrier_id_[call_ref];
547
548
      auto new_args = op->args;
      new_args.Set(0, barrier_id);
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
      if (!has_warp_specialization_)
        clear_arrive_ = false;
      else
        clear_arrive_ = clear_expect_list_[cur_expect_idx_++];
      if (clear_arrive_) {
        return Call(op->dtype, builtin::ptx_arrive_barrier_expect_tx(),
                    new_args);
      }
      return Call(op->dtype, op->op, new_args);
    } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
      if (clear_arrive_) {
        clear_arrive_ = false;
        return 0;
      }
      // by default, all threads must wait.
      auto new_args = op->args;
565
566
567
568
569
      return Call(op->dtype, op->op, new_args);
    }
    return IRMutatorWithAnalyzer::VisitExpr_(op);
  }
  Map<ObjectRef, PrimExpr> tma_op_to_barrier_id_;
570
571
  Map<PrimExpr, IntImm> barrier_id_to_range_;
  bool has_create_list_of_mbarrier_;
572
573
574
575
576
577
578
  bool clear_arrive_{false};
  bool first_if{false}, has_warp_specialization_{false}, is_producer_{false};
  IterVar thread_var_;
  int tma_expect_tx_{0}, cur_expect_idx_{0};
  std::vector<bool> clear_expect_list_;
  std::vector<int> restore_barrier_ids_;
  PrimExpr producer_thread_extent_;
579
580
581
};

/*!
 * \brief Create the tl.InjectTmaBarrier PrimFunc pass.
 *
 * Functions that use thread tags other than threadIdx.x are returned
 * unchanged, with a warning; all others go through
 * TmaBarrierRewriter::Rewrite.
 */
tvm::transform::Pass InjectTmaBarrier() {
  auto pass_func = [](PrimFunc f, const IRModule & /*m*/,
                      const PassContext & /*ctx*/) {
    if (ThreadTagChecker::HasOnlyThreadIdxX(f)) {
      arith::Analyzer analyzer;
      return TmaBarrierRewriter::Rewrite(f, &analyzer);
    }
    LOG(WARNING) << "InjectTmaBarrier will be disabled because the program "
                    "uses thread tags other than threadIdx.x\n"
                 << "If you want to use TMA barrier, please refactor "
                    "your program to use threadIdx.x only";
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.InjectTmaBarrier", {});
}

598
TVM_FFI_STATIC_INIT_BLOCK() {
599
600
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.InjectTmaBarrier", InjectTmaBarrier);
601
}
602
603
604

} // namespace tl
} // namespace tvm