/*!
 * \file warp_specialized_rewriter.cc
 * \brief Warp-specialized pipeline rewriter for CUDA GPUs (sm90+)
 */

6
#include "arith/ir_visitor_with_analyzer.h"
7
#include "tir/analysis/var_use_def_analysis.h"
8
#include <tvm/ffi/reflection/registry.h>
9
10
11
12
13
14
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

15
16
#include <utility>

17
#include "../op/builtin.h"
18
#include "./common/collector.h"
19
20
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
21
22
23
24
25

namespace tvm {
namespace tl {

using namespace tir;
26
using namespace runtime;
27
using arith::IRVisitorWithAnalyzer;
28

29
30
31
32
33
34
/*!
 * \brief Plain record bundling a loop variable with two bound expressions.
 *
 * NOTE(review): the field names suggest this mirrors the loop_var / min /
 * extent triple of a TIR For loop; no use of the struct is visible in this
 * part of the file, so confirm against its callers.
 */
struct LoopInfo {
  Var loop_var;
  PrimExpr extent;
  PrimExpr min;
};

35
/*!
 * \brief Role assigned to a statement by WarpSpecializedRoleMarker.
 *
 * kConsumer is the default role, kProducer marks statements such as TMA loads
 * and shared-memory stores fed only from global buffers, and kBoth marks
 * statements whose children disagree (or that must appear on both paths, e.g.
 * loop_break).
 */
enum class Role : uint8_t { kConsumer, kProducer, kBoth };
36

37
class ProducerBufferDetector : public StmtExprVisitor {
38
public:
39
40
  ProducerBufferDetector(
      std::unordered_set<const BufferNode *> cur_producer_buffers)
41
      : cur_producer_buffers_(std::move(cur_producer_buffers)) {}
42
43

  void clear() { has_producer_buffer_ = false; }
44
45

  void VisitExpr_(const CallNode *call) final {
46
    if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
47
      has_producer_buffer_ = true;
48
    }
49
    StmtExprVisitor::VisitExpr_(call);
50
51
  }

52
53
54
55
56
57
58
59
60
  void VisitExpr_(const BufferLoadNode *op) final {
    if (cur_producer_buffers_.count(op->buffer.get())) {
      has_producer_buffer_ = true;
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  bool has_producer_buffer_ = false;
  std::unordered_set<const BufferNode *> cur_producer_buffers_;
61
62
63
64
};

class ProducerUsedBufferFinder : public StmtExprVisitor {
public:
65
  auto FindProducerusedBuffer(const Stmt &stmt) {
66
67
68
69
70
71
72
73
74
75
    producer_buffers_.clear();
    std::unordered_set<const BufferNode *> last_producer_buffers_;
    for (;;) {
      VisitStmt(stmt);
      if (producer_buffers_ == last_producer_buffers_) {
        break;
      }
      last_producer_buffers_ = producer_buffers_;
    }
    return producer_buffers_;
76
77
78
79
80
81
82
  }

  void InsertBuffer(const PrimExpr &expr) {
    // Find the buffer that is used in the condition
    VarUseDefAnalyzer usage(Array<Var>{});
    usage(expr);
    for (const auto &buffer : usage.buffer_use_count_) {
83
      producer_buffers_.insert(buffer.first);
84
85
86
87
    }
  }

  void VisitStmt_(const IfThenElseNode *op) final {
88
89
    ProducerBufferDetector producer_buffer_detector(producer_buffers_);
    producer_buffer_detector(op->then_case);
90
    if (op->else_case.defined()) {
91
      producer_buffer_detector(op->else_case.value());
92
    }
93
    if (producer_buffer_detector.has_producer_buffer_) {
94
95
96
97
98
99
      InsertBuffer(op->condition);
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const ForNode *op) final {
100
101
102
    ProducerBufferDetector producer_buffer_detector(producer_buffers_);
    producer_buffer_detector(op->body);
    if (producer_buffer_detector.has_producer_buffer_) {
103
104
105
106
107
108
      InsertBuffer(op->min);
      InsertBuffer(op->extent);
    }
    StmtExprVisitor::VisitStmt_(op);
  }

109
110
111
112
113
114
115
  void VisitStmt_(const BufferStoreNode *op) final {
    if (producer_buffers_.count(op->buffer.get())) {
      InsertBuffer(op->value);
    }
    StmtExprVisitor::VisitStmt_(op);
  }

116
117
118
119
  void VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
      for (auto arg : op->args) {
        if (auto buffer_load = arg.as<BufferLoadNode>()) {
120
          producer_buffers_.insert(buffer_load->buffer.get());
121
122
123
124
125
        }
      }
    }
  }

126
private:
127
  std::unordered_set<const BufferNode *> producer_buffers_;
128
129
};

130
class WarpSpecializedRoleMarker : public StmtVisitor {
131
public:
132
  WarpSpecializedRoleMarker(Map<Var, Buffer> buffer_data_to_buffer)
133
      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)) {}
134

135
136
  void Prepare(const Stmt &stmt) {
    ProducerUsedBufferFinder finder;
137
    producer_buffers_ = finder.FindProducerusedBuffer(stmt);
138
139
  }

140
  Role GetRole(const StmtNode *stmt) const {
141
142
143
144
145
    auto it = map_.find(stmt);
    ICHECK(it != map_.end());
    return it->second;
  }

146
  Role GetRole(const Stmt &stmt) const { return GetRole(stmt.get()); }
147

148
  void VisitStmt_(const EvaluateNode *op) final {
149
150
    Role role = Role::kConsumer;
    if (auto call = op->value.as<CallNode>()) {
151
      if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
152
153
154
        role = Role::kProducer;
        has_bulk_copy_ = true;
      }
155
156
157
      if (call->op.same_as(loop_break())) {
        role = Role::kBoth;
      }
158
159
160
161
    }
    SetRole(op, role);
  }

162
  void VisitStmt_(const BufferStoreNode *op) final {
163
164
    auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer->data));
    bool is_shared_store = scope.rank == StorageRank::kShared;
165
    if (producer_buffers_.count(op->buffer.get())) {
166
167
168
      SetRole(op, Role::kBoth);
      return;
    }
169
170
171
172
173
174
175
176
177
178
179
    if (!is_shared_store) {
      SetRole(op, Role::kConsumer);
      return;
    }

    // Check reads from global
    Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
                /*body*/ GetRef<Stmt>(op));
    auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
    auto reads = access[0];
    Role role = Role::kProducer;
180
181
    if (reads.empty())
      role = Role::kConsumer;
182
183
184
185
186
187
    for (auto read : reads) {
      if (read->buffer.scope() != "global") {
        role = Role::kConsumer;
        break;
      }
    }
188
189
    if (role == Role::kProducer)
      has_simt_copy_ = true;
190
191
192
    SetRole(op, role);
  }

193
  void VisitStmt_(const SeqStmtNode *op) final {
194
195
196
197
198
199
200
201
202
203
204
    StmtVisitor::VisitStmt_(op);
    auto role = GetRole(op->seq[0]);
    for (auto stmt : op->seq) {
      if (role != GetRole(stmt)) {
        role = Role::kBoth;
        break;
      }
    }
    SetRole(op, role);
  }

205
  void VisitStmt_(const IfThenElseNode *op) final {
206
207
208
209
    StmtVisitor::VisitStmt_(op);
    auto role = GetRole(op->then_case);
    if (op->else_case.defined()) {
      auto role_else = GetRole(op->else_case.value());
210
211
      if (role != role_else)
        role = Role::kBoth;
212
213
214
215
    }
    SetRole(op, role);
  }

216
  void VisitStmt_(const BlockRealizeNode *op) final {
217
218
219
220
    StmtVisitor::VisitStmt_(op);
    SetRole(op, GetRole(op->block));
  }

221
222
223
224
225
226
  void VisitStmt_(const AllocateNode *op) final {
    StmtVisitor::VisitStmt_(op);
    Role role = Role::kConsumer;
    SetRole(op, role);
  }

227
  template <class NodeType> void HandleBodyStmt(const NodeType *op) {
228
229
230
231
    StmtVisitor::VisitStmt_(op);
    SetRole(op, GetRole(op->body));
  }

232
  void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); }
233
  void VisitStmt_(const WhileNode *op) final { HandleBodyStmt(op); }
234
235
236
237
  void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); }
238
239
240
241
242

  bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; }

  bool HasSimtCopy() { return has_simt_copy_; }

243
244
private:
  void SetRole(const StmtNode *stmt, Role role) { map_[stmt] = role; }
245
  Map<Var, Buffer> buffer_data_to_buffer_;
246
  std::unordered_map<const StmtNode *, Role> map_;
247
248
  bool has_simt_copy_ = false;
  bool has_bulk_copy_ = false;
249
  std::unordered_set<const BufferNode *> producer_buffers_;
250
251
252
};

/*! \brief Build a `get_mbarrier(barrier_id)` call expression (handle type). */
static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
  return Call(DataType::Handle(), get_mbarrier(), {std::move(barrier_id)});
}

256
/*!
 * \brief Build an Evaluate statement that calls `ptx_arrive_barrier`.
 *
 * \param barrier_id Index expression of the barrier to arrive on.
 * \param cta_id When not -1, `cta_id` and `pred` are appended as extra call
 *        arguments; when -1 (the default) both are omitted.
 * \param pred Predicate argument, used only together with `cta_id`.
 */
static Stmt makeArriveBarrier(PrimExpr barrier_id, int cta_id = -1,
                              const PrimExpr &pred = 1) {
  Array<PrimExpr> args = {makeGetBarrier(std::move(barrier_id))};
  const bool has_cta_target = cta_id != -1;
  if (has_cta_target) {
    args.push_back(cta_id);
    args.push_back(pred);
  }
  auto call = Call(DataType::Handle(), builtin::ptx_arrive_barrier(), args);
  return Evaluate(call);
}

/*!
 * \brief Build an Evaluate statement that calls `ptx_cp_async_barrier` on the
 *        barrier selected by `barrier_id`.
 */
static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) {
  PrimExpr barrier = makeGetBarrier(std::move(barrier_id));
  return Evaluate(
      Call(DataType::Handle(), builtin::ptx_cp_async_barrier(), {barrier}));
}

/*!
 * \brief Build an Evaluate statement that calls
 *        `mbarrier_wait_parity(get_mbarrier(barrier_id), parity)`.
 */
static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
  PrimExpr barrier = makeGetBarrier(std::move(barrier_id));
  return Evaluate(Call(DataType::Handle(), mbarrier_wait_parity(),
                       {barrier, std::move(parity)}));
}

class ProducerTraitsCollector : public StmtExprVisitor {
280
public:
281
282
  ProducerTraitsCollector() { Clear(); }

283
  void Clear() { has_simt_copy = false; }
284

285
  void Collect(const Stmt &stmt) { VisitStmt(stmt); }
286
287
288

  bool HasSimtCopy() { return has_simt_copy; }

289
private:
290
291
292
293
294
295
296
297
298
299
300
301
  void VisitStmt_(const IfThenElseNode *op) final {
    bool old_in_if_cond = in_if_cond_;
    in_if_cond_ = true;
    VisitExpr(op->condition);
    in_if_cond_ = old_in_if_cond;

    VisitStmt(op->then_case);
    if (op->else_case.defined()) {
      VisitStmt(op->else_case.value());
    }
  }

302
  void VisitExpr_(const BufferLoadNode *op) final {
303
304
305
    if (!in_if_cond_) {
      has_simt_copy = true;
    }
306
307
308
    StmtExprVisitor::VisitExpr_(op);
  }

309
  bool has_simt_copy{};
310
  bool in_if_cond_ = false;
311
312
313
314
};

// Rewrite the producer Stmt to use the correct barrier index
class MbarrierRewriter : public StmtExprMutator {
315
public:
316
317
  static Stmt Rewrite(Stmt stmt, PrimExpr barrier_id) {
    MbarrierRewriter rewriter;
318
319
    rewriter.producer_barrier_idx_ = std::move(barrier_id);
    return rewriter(std::move(stmt));
320
321
  }

322
323
private:
  PrimExpr VisitExpr_(const CallNode *op) final {
324
    auto call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
325
    if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
326
327
328
329
330
331
332
333
334
335
336
337
338
      auto mbar = makeGetBarrier(producer_barrier_idx_);
      auto arg0 = call->args[0].as<Call>();
      // Check if this is a 1D TMA load
      auto is_1d_tma_load =
          arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
          call->op.same_as(tma_load());
      if (is_1d_tma_load) {
        call.CopyOnWrite()->args.Set(2, mbar);
      } else {
        Call access_ptr = Downcast<Call>(call->args[2]);
        ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
        call.CopyOnWrite()->args.Set(1, mbar);
      }
339
340
341
342
343
344
345
    }
    return call;
  }
  PrimExpr producer_barrier_idx_;
};

class ThreadIdxRewriter : public StmtExprMutator {
346
public:
347
348
349
  static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced,
                      PrimExpr thread_extent, bool do_shuffle = false) {
    auto rewriter =
350
351
352
        ThreadIdxRewriter(std::move(thread_var), std::move(replaced),
                          std::move(thread_extent), do_shuffle);
    return rewriter(std::move(stmt));
353
354
  }

355
private:
356
357
  ThreadIdxRewriter(Var thread_var, PrimExpr replaced, PrimExpr thread_extent,
                    bool do_shuffle)
358
359
      : thread_var_(std::move(thread_var)), replaced_(std::move(replaced)),
        thread_extent_(std::move(thread_extent)), do_shuffle_(do_shuffle) {}
360

361
  PrimExpr VisitExpr_(const VarNode *var) final {
362
363
364
365
366
367
368
    if (var == thread_var_.get()) {
      return replaced_;
    } else {
      return StmtExprMutator::VisitExpr_(var);
    }
  }

369
370
371
372
373
374
375
376
377
378
379
380
381
  Stmt VisitStmt_(const IfThenElseNode *op) final {
    auto f_uses_thread_index = [=](const tvm::tir::VarNode *parameter) {
      return parameter == thread_var_.get();
    };
    maybe_thread_opt_ = false;
    if (!op->else_case.defined() && op->condition.as<EQNode>() &&
        UsesVar(op->condition, f_uses_thread_index) &&
        !(UsesVar(op->then_case, f_uses_thread_index))) {
      auto eq_op = Downcast<EQ>(op->condition);
      if (eq_op->a.as<VarNode>() == thread_var_.get() ||
          eq_op->b.as<VarNode>() == thread_var_.get()) {
        maybe_thread_opt_ = true;
      }
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
      auto then_case = StmtExprMutator::VisitStmt(op->then_case);
      maybe_thread_opt_ = do_shuffle_ && maybe_thread_opt_ && has_tma_op_;
      has_tma_op_ = false;
      if (maybe_thread_opt_) {
        return IfThenElse(
            Call(DataType::Bool(), tl_shuffle_elect(), {thread_extent_}),
            StmtExprMutator::VisitStmt(op->then_case), std::nullopt);
      }
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  PrimExpr VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(tl::tma_load()) ||
        op->op.same_as(tl::tma_load_im2col()) ||
        op->op.same_as(tl::tma_store())) {
      has_tma_op_ = true;
399
    }
400
    return StmtExprMutator::VisitExpr_(op);
401
402
  }

403
404
  Var thread_var_;
  PrimExpr replaced_;
405
406
407
  PrimExpr thread_extent_;
  bool maybe_thread_opt_ = false;
  bool do_shuffle_;
408
  bool has_tma_op_ = false;
409
410
};

411
412
413
414
415
416
/*!
 * \brief Wrap `stmt` in an anonymous Block carrying the given annotations.
 *
 * The block has no iter vars, read/write regions, init statement, allocated
 * buffers or match buffers, and an empty name hint.
 */
Block MakeGroupBlock(const Stmt &stmt,
                     const Map<String, ObjectRef> &annotations) {
  return Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
               /*name_hint=*/"",
               /*body*/ stmt,
               /*init=*/{}, /*alloc_buffers=*/{}, /*match_buffers=*/{},
               /*annotations=*/annotations);
}

/*!
 * \brief Description of one group of statements in a software pipeline.
 *
 * All integer fields are value-initialized to zero.
 */
struct OpInfo {
  /*! \brief Number of statements in the group (kept equal to group.size()). */
  int group_size{};
  /*! \brief Pipeline order assigned to the group. */
  int order{};
  /*! \brief Pipeline stage assigned to the group. */
  int stage{};
  /*! \brief Indices of the statements that belong to the group. */
  std::vector<int> group;
};
struct PipelineInfo {
  std::vector<OpInfo> op_infos;

  PipelineInfo() = default;
428
429
430
  PipelineInfo(const Array<Array<Integer>> &group_info,
               const Array<Integer> &order_info,
               const Array<Integer> &stage_info) {
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
    int n = static_cast<int>(group_info.size());
    ICHECK(n == static_cast<int>(order_info.size()));
    ICHECK(n == static_cast<int>(stage_info.size()));
    // int cur_id = 0;
    for (int i = 0; i < n; i++) {
      OpInfo op_info;
      op_info.group_size = group_info[i].size();
      for (int j = 0; j < op_info.group_size; j++) {
        op_info.group.push_back(group_info[i][j].as<IntImmNode>()->value);
      }
      op_info.order = order_info[i].as<IntImmNode>()->value;
      op_info.stage = stage_info[i].as<IntImmNode>()->value;
      op_infos.push_back(op_info);
    }
  }

447
  PipelineInfo(const PipelineInfo &other) {
448
    for (const auto &op_info : other.op_infos) {
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
      op_infos.push_back(op_info);
    }
  }

  std::pair<int, int> FindStmt(int stmt_idx) {
    for (size_t i = 0; i < op_infos.size(); i++) {
      for (size_t j = 0; j < op_infos[i].group.size(); j++) {
        if (op_infos[i].group[j] == stmt_idx) {
          return std::make_pair(i, j);
        }
      }
    }
    return std::make_pair(-1, -1);
  }

  void UpdateOrder(int order) {
    for (int i = 0; i < static_cast<int>(op_infos.size()); i++) {
      if (op_infos[i].order >= order && op_infos[i].order > 0) {
        op_infos[i].order++;
      }
    }
  }

  int SplitOp(int stmt_idx) {
    auto pair = FindStmt(stmt_idx);
    int op_idx = pair.first;
    int inner_idx = pair.second;
    ICHECK(op_idx != -1);
    ICHECK(inner_idx != -1);
    OpInfo half0;
    OpInfo half1;
    // The order to do sync
    int sync_order = op_infos[op_idx].order + 1;
    UpdateOrder(sync_order);

    half0.group_size = inner_idx + 1;
    half0.order = op_infos[op_idx].order;
    half0.stage = op_infos[op_idx].stage;
    for (int i = 0; i <= inner_idx; i++) {
      half0.group.push_back(op_infos[op_idx].group[i]);
    }
    half1.group_size = op_infos[op_idx].group_size - inner_idx - 1;
    half1.order = op_infos[op_idx].order + 2;
    half1.stage = op_infos[op_idx].stage;
    for (int i = inner_idx + 1; i < op_infos[op_idx].group_size; i++) {
      half1.group.push_back(op_infos[op_idx].group[i]);
    }
    op_infos.erase(op_infos.begin() + op_idx);
    if (half0.group_size > 0) {
      op_infos.insert(op_infos.begin() + op_idx, half0);
    }
    if (half1.group_size > 0) {
      UpdateOrder(half1.order);
      op_infos.insert(op_infos.begin() + op_idx + 1, half1);
    }
    return sync_order;
  }

  void PrintPipelineInfo() {
508
    std::cout << "Print op_infos:" << '\n';
509
    for (size_t i = 0; i < op_infos.size(); i++) {
510
      std::cout << i << " " << op_infos[i].group_size << " "
511
                << op_infos[i].order << " " << op_infos[i].stage << '\n';
512
    }
513
    std::cout << "End of print" << '\n';
514
515
516
517
  }
};

class GroupOpRewriter : public StmtExprMutator {
518
public:
519
520
  GroupOpRewriter(const PipelineInfo &pipeline_info)
      : pipeline_info_(pipeline_info) {}
521

522
523
private:
  Stmt VisitStmt_(const ForNode *op) final {
524
525
526
527
528
529
530
531
532
    Map<String, ObjectRef> annotations;
    annotations.Set(String("stmt_group"), Integer(1));
    auto original_node = (op->body).as<SeqStmtNode>();
    if (!original_node) {
      return GetRef<For>(op);
    }
    Array<Stmt> new_body;
    int cur_id = 0;
    for (int i = 0; i < static_cast<int>(pipeline_info_.op_infos.size()); i++) {
533
534
      if (pipeline_info_.op_infos[i].group_size == 0)
        continue;
535
      Array<Stmt> block_stmt;
536
537
      for (int j = 0;
           j < static_cast<int>(pipeline_info_.op_infos[i].group_size); j++) {
538
        // ICHECK(group_info_[i][j].as<IntImmNode>());
539
540
        // int index =
        // static_cast<int>(group_info_[i][j].as<IntImmNode>()->value);
541
542
543
544
545
546
        ICHECK(original_node->seq[cur_id].as<BlockNode>());
        auto block = original_node->seq[cur_id].as<BlockNode>();
        // TODO: handle nested seqstmt
        block_stmt.push_back(block->body);
        cur_id++;
      }
547
548
549
550
      new_body.push_back(MakeGroupBlock(block_stmt.size() == 1
                                            ? block_stmt[0]
                                            : SeqStmt(std::move(block_stmt)),
                                        annotations));
551
552
553
    }
    Array<Integer> order_anno;
    Array<Integer> stage_anno;
554
    for (const auto &op_info : pipeline_info_.op_infos) {
555
556
557
      order_anno.push_back(Integer(op_info.order));
      stage_anno.push_back(Integer(op_info.stage));
    }
558
    Map<String, Any> for_annotations = op->annotations;
559
560
561
    for_annotations.erase("tl_pipeline_group");
    for_annotations.Set("software_pipeline_order", order_anno);
    for_annotations.Set("software_pipeline_stage", stage_anno);
562
563
564
565
    For new_for =
        For(op->loop_var, op->min, op->extent, op->kind,
            new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body)),
            op->thread_binding, for_annotations);
566
567
568
569
570
    return new_for;
  }

  PipelineInfo pipeline_info_;
};
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595

class WgMMACollector : public StmtExprVisitor {
public:
  WgMMACollector() = default;

  void VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(tl_gemm()) || op->op.same_as(tl_gemm_sp())) {
      auto op_name = std::string(op->args[0].as<StringImmNode>()->value);
      if (has_wgmma_) {
        has_wgmma_ =
            op_name.find("false") == std::string::npos && !in_if_scope_;
      }
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const IfThenElseNode *op) final {
    in_if_scope_ = true;
    StmtExprVisitor::VisitStmt(op->then_case);
    if (op->else_case.defined()) {
      StmtExprVisitor::VisitStmt(op->else_case.value());
    }
    in_if_scope_ = false;
  }

596
  static bool HasWgMMA(const Stmt &stmt) {
597
598
599
600
601
602
603
604
605
    auto collector = WgMMACollector();
    collector(stmt);
    return collector.has_wgmma_;
  }

  bool has_wgmma_{true};
  bool in_if_scope_{false};
};

606
class WSCodeEmitter : public StmtMutator {
607
public:
608
  /**
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
   * @brief Construct a warp-specialized code emitter configured for producer or
   * consumer emission.
   *
   * Initializes a WSCodeEmitter that will emit barrier-aware, role-filtered
   * code for a single warp-specialized block. The emitter is configured with
   * the loop/thread iteration variable, buffer mapping, role marker used to
   * classify statements, and two flags that control emission behavior:
   *
   * - `mbarrier_only`: when true, emission is restricted to barrier-related
   * operations only.
   * - `only_has_wgmma`: when true, the emitter will account for the presence of
   * WgMMA (workgroup MMA) operations when computing barrier/thread gating
   * behavior.
   *
   * @param is_emitting_producer True to emit producer-side groups; false to
   * emit consumer-side groups.
   * @param thread_iv IterVar representing the thread iteration variable
   * (threadIdx.*) whose Var is used for thread-index rewrites and gating.
   * @param buffer_data_to_buffer Map from buffer data Var to the corresponding
   * Buffer (used to resolve buffer references during emission).
   * @param marker Role marker that classifies statements as
   * producer/consumer/both; used to filter which statements are emitted on this
   * path.
   * @param mbarrier_only If true, restrict emission to mbarrier-related
   * statements and helpers.
   * @param only_has_wgmma If true, adjust emission and barrier-thread-count
   * logic for blocks that contain WgMMA operations.
   */
637
  WSCodeEmitter(bool is_emitting_producer, const IterVar &thread_iv,
638
                Map<Var, Buffer> buffer_data_to_buffer,
639
                const WarpSpecializedRoleMarker &marker,
640
                bool mbarrier_only = false, bool only_has_wgmma = false)
641
      : is_emitting_producer_(is_emitting_producer),
642
643
644
        buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
        marker_(marker), thread_var_(thread_iv->var),
        mbarrier_only_(mbarrier_only), only_has_wgmma_(only_has_wgmma) {}
645

646
  /**
   * @brief Whether a SIMT copy was found while emitting producer code.
   *
   * The flag is set during producer emission when a rewritten producer
   * statement is reported by ProducerTraitsCollector as containing a SIMT
   * copy (in which case a cp.async barrier is inserted as well).
   *
   * @return true if a SIMT copy was detected; false otherwise.
   */
  bool hasSimtCopy() const { return has_simt_copy_; }
656

657
658
659
660
661
662
663
664
665
666
  /**
   * @brief Whether this emitter contains only warp-group MMA (WgMMA)
   * operations.
   *
   * Returns true if the emitter detected exclusively WgMMA usage in the region
   * it analyzed.
   *
   * @return bool true when only WgMMA-based code paths are present; false
   * otherwise.
   */
667
668
  bool onlyHasWgMMA() const { return only_has_wgmma_; }

669
private:
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
  /**
   * @brief Filter a statement by its producer/consumer role for emission.
   *
   * The role of `op` is looked up in `marker_`, then:
   * - in mbarrier-only mode, any non-producer statement is traversed with the
   *   default mutator;
   * - a statement marked Both is traversed with the default mutator;
   * - a statement whose role matches the path being emitted
   *   (`is_emitting_producer_`) is returned unchanged;
   * - any other statement is replaced by an empty `Evaluate(0)`.
   *
   * @param op The statement node to filter.
   * @return Stmt The statement to place into the emitted IR.
   */
  template <typename NodeType> Stmt FilterByRole(const NodeType *op) {
    const Role role = marker_.GetRole(op);
    const bool traverse_for_mbarrier =
        mbarrier_only_ && role != Role::kProducer;
    if (traverse_for_mbarrier || role == Role::kBoth) {
      return StmtMutator::VisitStmt_(op);
    }
    const bool belongs_to_emitted_path =
        (role == Role::kProducer) == is_emitting_producer_;
    if (belongs_to_emitted_path) {
      return GetRef<Stmt>(op);
    }
    return Evaluate(0);
  }

709
  /**
710
711
   * @brief Visit and transform a SeqStmt node, emitting grouped blocks with
   * barrier synchronization according to producer/consumer roles.
712
713
   *
   * This method examines the sequence to determine whether producer-side
714
715
   * synchronization is required (based on marker_ roles). If no producer sync
   * is needed it delegates to FilterByRole. Otherwise it:
716
717
718
719
   * - Recursively visits and transforms each child statement.
   * - Extracts an acquire/release sync pattern for the sequence via
   *   ExtractSyncPattern.
   * - For producer emission (is_emitting_producer_ == true):
720
721
   *   - Skips consumer-only statements unless marker_ marks a statement as
   * Both, in which case the statement is emitted as its own group.
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
   *   - For each statement, inserts parity waits for acquire patterns, rewrites
   *     release statements with MbarrierRewriter using a computed barrier id,
   *     collects SimT-copy presence (setting has_simt_copy_ and inserting
   *     cp.async barriers when found), optionally emits arrive barriers for
   *     release-after events, and emits each resulting set of statements as a
   *     group block annotated with "stmt_group".
   * - For consumer emission (is_emitting_producer_ == false):
   *   - Skips producer-only statements.
   *   - Inserts parity waits for acquire patterns, appends the transformed
   *     statement, and emits arrive barriers for release-after events. When
   *     only_has_wgmma_ is set, the arrive barrier uses a per-thread predicate
   *     (FloorMod(thread_var_,128)==0) with CTA=0; otherwise a full arrive is
   *     emitted.
   *   - Recomputes pipeline_info_ to drop producer-only ops.
   *
   * Side effects / state updates:
   * - Increments num_barriers_ by (number of extracted patterns * num_stages_).
   * - May set has_simt_copy_ when a SimT copy is detected in producer rewrites.
   * - Inserts barrier ids into released_barrier_ for release-after events.
   * - Updates pipeline_info_ for the consumer path to remove producer ops.
   *
   * The resulting statements are emitted as grouped blocks (via MakeGroupBlock)
   * with the annotation "stmt_group" and returned as either a single Stmt (if
   * there's only one group) or a SeqStmt containing the grouped blocks.
   *
   * @return Stmt The transformed statement (either a single group block or a
   * SeqStmt of group blocks).
   */
750
  Stmt VisitStmt_(const SeqStmtNode *op) final {
751

752
753
754
755
756
757
758
    bool has_producer = false;
    for (auto stmt : op->seq) {
      if (marker_.GetRole(stmt) == Role::kProducer) {
        has_producer = true;
        break;
      }
    }
759
760
761
762
    bool need_producer_sync =
        has_producer && marker_.GetRole(op) == Role::kBoth;
    if (!need_producer_sync)
      return FilterByRole(op);
763

764
    auto seq_transformed =
765
        op->seq.Map([&](const Stmt &stmt) { return VisitStmt(stmt); });
766
767

    auto map = ExtractSyncPattern(op->seq);
768

769
770
771
772
773
774
775
776
777
778
779
780
781
782
    /*
      std::cout << "Print ExtractSyncPattern" << std::endl;
      for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
        std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " "
        << map.release_after[i] << std::endl;
      }
      std::cout << "Print sync pattern" << std::endl;
      for (auto pattern : map.patterns) {
        std::cout << pattern.release_idx << " " << pattern.acquire_idx <<
        std::endl;
      }
      std::cout << "End of ExtractSyncPattern" << std::endl;
      pipeline_info_.PrintPipelineInfo();
    */
783
784
785
786
    Array<Stmt> new_body;
    Map<String, ObjectRef> annotations;
    annotations.Set(String("stmt_group"), Integer(1));

787
    if (is_emitting_producer_) { // producer case
788
789
790
      ProducerTraitsCollector collector;
      for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
        Array<Stmt> block_stmt = {};
791
792
793
794
795
796
797
798
799
800
801
        if (!mbarrier_only_) {
          if (marker_.GetRole(op->seq[i]) == Role::kConsumer)
            continue;
          if (marker_.GetRole(op->seq[i]) == Role::kBoth) {
            block_stmt.push_back(seq_transformed[i]);
            new_body.push_back(MakeGroupBlock(
                block_stmt.size() == 1 ? block_stmt[0]
                                       : SeqStmt(std::move(block_stmt)),
                annotations));
            continue;
          }
802
        }
803

804
        for (int pattern_idx : map.acquire[i]) {
805
          PrimExpr acquire_barrier_id =
806
807
              stage_ + num_barriers_ + num_stages_ * pattern_idx;
          PrimExpr parity = map.is_loop_dependency(pattern_idx)
808
809
                                ? bitwise_xor(parity_, 1)
                                : parity_;
810
811
          block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
        }
812
        ICHECK(!map.release[i].empty());
813
814
815
816
817
818
819
        for (size_t j = 0; j < map.release[i].size(); j++) {
          int pattern_idx = map.release[i][j];
          PrimExpr release_barrier_id =
              stage_ + num_barriers_ + num_stages_ * pattern_idx;
          auto stmt =
              MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id);
          collector.Collect(stmt);
820
          block_stmt.push_back(stmt);
821
          if (collector.HasSimtCopy()) {
822
            block_stmt.push_back(makeCpAsyncBarrier(release_barrier_id));
823
            has_simt_copy_ = true;
824
          }
825
826
827
828
829
830
831
832
833
834
835
836
          if (map.release_after[i][j]) {
            block_stmt.push_back(makeArriveBarrier(release_barrier_id));
            for (int s = 0; s < num_stages_; s++) {
              released_barrier_.insert(s + num_barriers_ +
                                       num_stages_ * pattern_idx);
            }
          }
          collector.Clear();
          new_body.push_back(MakeGroupBlock(
              block_stmt.size() == 1 ? block_stmt[0]
                                     : SeqStmt(std::move(block_stmt)),
              annotations));
837
838
        }
      }
839
    } else { // consumer case
840
841
      for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
        Array<Stmt> block_stmt = {};
842
843
        if (marker_.GetRole(op->seq[i]) == Role::kProducer)
          continue;
844
        for (int pattern_idx : map.acquire[i]) {
845
          PrimExpr acquire_barrier_id =
846
847
              stage_ + num_barriers_ + num_stages_ * pattern_idx;
          PrimExpr parity = map.is_loop_dependency(pattern_idx)
848
849
                                ? bitwise_xor(parity_, 1)
                                : parity_;
850
851
852
          block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
        }
        block_stmt.push_back(seq_transformed[i]);
853
854
855
856
857
        for (size_t j = 0; j < map.release[i].size(); j++) {
          if (map.release_after[i][j]) {
            int pattern_idx = map.release[i][j];
            PrimExpr release_barrier_id =
                stage_ + num_barriers_ + num_stages_ * pattern_idx;
858
859
860
861
862
            if (only_has_wgmma_)
              block_stmt.push_back(makeArriveBarrier(
                  release_barrier_id, 0, EQ(FloorMod(thread_var_, 128), 0)));
            else
              block_stmt.push_back(makeArriveBarrier(release_barrier_id));
863
864
865
866
            for (int s = 0; s < num_stages_; s++) {
              released_barrier_.insert(s + num_barriers_ +
                                       num_stages_ * pattern_idx);
            }
867
868
          }
        }
869
870
871
872
        new_body.push_back(MakeGroupBlock(block_stmt.size() == 1
                                              ? block_stmt[0]
                                              : SeqStmt(std::move(block_stmt)),
                                          annotations));
873
874
875
876
      }
      // Filter out the producer stmts
      int cur_id = 0;
      PipelineInfo new_pipeline_info;
877
878
      for (int i = 0; i < static_cast<int>(pipeline_info_.op_infos.size());
           i++) {
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
        auto op_info = pipeline_info_.op_infos[i];
        bool is_producer = false;
        for (int j = 0; j < op_info.group_size; j++) {
          if (marker_.GetRole(op->seq[cur_id]) == Role::kProducer) {
            is_producer = true;
          }
          cur_id++;
        }
        if (is_producer) {
          ICHECK(op_info.group_size == 1);
        } else {
          new_pipeline_info.op_infos.push_back(op_info);
        }
      }
      pipeline_info_ = new_pipeline_info;
    }

    num_barriers_ += map.patterns.size() * num_stages_;

898
    ICHECK(!new_body.empty());
899
900
901
    return new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body));
  }

902
  /**
   * @brief Visit a loop, set up per-loop pipeline state, and filter its body
   * by role.
   *
   * Reads the "num_stages", "tl_pipeline_group", "tl_pipeline_order" and
   * "tl_pipeline_stage" annotations, derives stage_/parity_ from the
   * linearized index of all loops currently on loop_stack_, and delegates the
   * actual rewriting to FilterByRole. All per-loop state (parity_, stage_,
   * num_stages_, pipeline_info_, loop_stack_) is restored before returning.
   *
   * If the result is still a For node, its "num_stages" annotation is erased.
   * The order/stage annotations are erased as well when emitting the producer
   * or when there is no group information. When emitting the consumer with
   * non-empty group information, the loop produced by GroupOpRewriter is
   * returned instead.
   *
   * Aborts via ICHECK on nested pipelines.
   */
  Stmt VisitStmt_(const ForNode *op) final {
    int num_stages = 1;
    auto num_stages_anno = op->annotations.Get("num_stages");
    if (num_stages_anno) {
      ICHECK(num_stages_anno->as<IntImmNode>());
      num_stages = static_cast<int>(num_stages_anno->as<IntImmNode>()->value);
      ICHECK(num_stages_ == 1) << "Nested pipeline not supported.";
    }
    loop_stack_.emplace_back(LoopInfo{op->loop_var, op->extent, op->min});

    Array<Array<Integer>> group_info_array;
    Array<Integer> order_info_array;
    Array<Integer> stage_info_array;

    auto group_anno = op->annotations.Get("tl_pipeline_group");
    if (group_anno) {
      group_info_array = Downcast<Array<Array<Integer>>>(group_anno.value());
    }
    auto order_anno = op->annotations.Get("tl_pipeline_order");
    if (order_anno) {
      order_info_array = Downcast<Array<Integer>>(order_anno.value());
    }
    auto stage_anno = op->annotations.Get("tl_pipeline_stage");
    if (stage_anno) {
      stage_info_array = Downcast<Array<Integer>>(stage_anno.value());
    }

    PipelineInfo pipeline_info(group_info_array, order_info_array,
                               stage_info_array);
    if (!pipeline_info.op_infos.empty()) {
      ICHECK(pipeline_info_.op_infos.empty())
          << "Nested pipeline not supported.";
    }

    // Save the enclosing loop's state so it can be restored after the body
    // has been processed.
    PrimExpr parity_before = std::move(parity_);
    PrimExpr stage_before = std::move(stage_);
    int num_stages_before = num_stages_;
    PipelineInfo pipeline_info_before = pipeline_info_;

    num_stages_ = num_stages;
    pipeline_info_ = pipeline_info;

    // Linearize the iteration index across every loop on the stack
    // (outermost first); the stack is non-empty because this loop was just
    // pushed.
    PrimExpr linear_index = loop_stack_[0].loop_var - loop_stack_[0].min;
    for (size_t i = 1; i < loop_stack_.size(); ++i) {
      linear_index = linear_index * loop_stack_[i].extent +
                     (loop_stack_[i].loop_var - loop_stack_[i].min);
    }
    stage_ = FloorMod(linear_index, num_stages);
    parity_ = FloorMod(
        parity_before * op->extent + FloorDiv(linear_index, num_stages), 2);

    auto result = FilterByRole(op);

    const bool result_is_for = result.as<ForNode>() != nullptr;
    const bool has_group_info = group_anno && !group_info_array.empty();
    // Only the consumer side keeps the grouped form of the loop.
    const bool use_grouped_loop =
        result_is_for && has_group_info && !is_emitting_producer_;

    // The grouped loop is built from the result *before* annotations are
    // stripped, and with the pipeline_info_ left behind by FilterByRole.
    Stmt grouped_for_node;
    if (use_grouped_loop) {
      GroupOpRewriter group_op_rewriter(pipeline_info_);
      grouped_for_node = group_op_rewriter(Downcast<For>(result));
    }

    // Restore the enclosing loop's state.
    parity_ = std::move(parity_before);
    stage_ = std::move(stage_before);
    num_stages_ = num_stages_before;
    pipeline_info_ = pipeline_info_before;
    loop_stack_.pop_back();

    if (!result_is_for) {
      return result;
    }
    if (use_grouped_loop) {
      return grouped_for_node;
    }

    // remove pipeline annotation
    auto for_node = Downcast<For>(result);
    for_node.CopyOnWrite()->annotations.erase("num_stages");
    if (is_emitting_producer_ || group_info_array.empty()) {
      for_node.CopyOnWrite()->annotations.erase("tl_pipeline_order");
      for_node.CopyOnWrite()->annotations.erase("tl_pipeline_stage");
    }
    return for_node;
  }

986
987
988
989
990
991
992
  // Every statement kind below is handled uniformly: the node is handed to
  // FilterByRole, which decides what is emitted for it.
  Stmt VisitStmt_(const IfThenElseNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const EvaluateNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const AttrStmtNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const BufferStoreNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const LetStmtNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const AssertStmtNode *op) final { return FilterByRole(op); }
  /**
   * @brief Block nodes are not supported by this mutator.
   *
   * Always aborts via ICHECK; the trailing return only satisfies the
   * signature.
   */
  Stmt VisitStmt_(const BlockNode *op) final {
    ICHECK(0) << "Unexpected Block statement: this mutator does not handle "
                 "Block nodes.";
    return Stmt();
  }
996
  /**
   * @brief BlockRealize nodes are not supported by this mutator.
   *
   * Always aborts via ICHECK; the trailing return only satisfies the
   * signature.
   */
  Stmt VisitStmt_(const BlockRealizeNode *op) final {
    ICHECK(0) << "Unexpected BlockRealize statement: this mutator does not "
                 "handle BlockRealize nodes.";
    return Stmt();
  }

  /**
   * @brief One release -> acquire dependency between two statements of a
   * sequence, identified by their positions in that sequence.
   */
  struct SyncPattern {
    int release_idx; ///< Index of the statement that releases.
    int acquire_idx; ///< Index of the statement that acquires.
  };

  /**
   * @brief Per-statement view of a list of SyncPattern entries.
   *
   * The three per-statement tables are indexed by statement position and hold
   * indices into `patterns`.
   */
  struct SyncPatternMap {
    /// acquire[i]: patterns statement i has to wait on.
    std::vector<std::vector<int>> acquire;
    /// release[i]: patterns associated with statement i on the release side.
    std::vector<std::vector<int>> release;
    /// release_after[i][j]: whether an arrive is emitted after statement i
    /// for release[i][j].
    std::vector<std::vector<bool>> release_after;
    /// The patterns the tables above refer to.
    std::vector<SyncPattern> patterns;

    /// Size the per-statement tables for `n` statements.
    void resize(size_t n) {
      acquire.resize(n);
      release.resize(n);
      release_after.resize(n);
    }

    /**
     * @brief Whether the pattern's releasing statement comes after its
     * acquiring statement in the sequence.
     * @param pattern_idx Index into `patterns`; must be in range.
     */
    bool is_loop_dependency(int pattern_idx) const {
      const SyncPattern &pattern = patterns[pattern_idx];
      return pattern.release_idx > pattern.acquire_idx;
    }
  };

1023
  std::vector<SyncPattern>
1024
  CreateBaseSyncPairs(const Array<Stmt> &seq_stmt,
1025
                      const std::vector<bool> &is_producer) {
1026
    const int n = seq_stmt.size();
1027
    std::vector<std::set<const BufferNode *>> reads, writes;
1028
1029
1030
    reads.reserve(n);
    writes.reserve(n);
    for (int i = 0; i < n; i++) {
1031
1032
      Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
                  /*name_hint=*/"",
1033
1034
                  /*body*/ seq_stmt[i]);
      auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
1035
      std::set<const BufferNode *> read_set, write_set;
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
      for (auto region : access[0]) {
        auto var = region->buffer->data;
        if (buffer_data_to_buffer_.count(var)) {
          read_set.insert(buffer_data_to_buffer_[var].get());
        } else {
          read_set.insert(region->buffer.get());
        }
      }
      for (auto region : access[1]) {
        auto var = region->buffer->data;
        if (buffer_data_to_buffer_.count(var)) {
          write_set.insert(buffer_data_to_buffer_[var].get());
        } else {
          write_set.insert(region->buffer.get());
        }
      }
1052
1053
1054
1055
      reads.push_back(std::move(read_set));
      writes.push_back(std::move(write_set));
    }

1056
1057
    auto intersect_fn = [](const std::set<const BufferNode *> &lhs,
                           const std::set<const BufferNode *> &rhs) {
1058
      for (auto ptr : lhs)
1059
1060
        if (rhs.count(ptr))
          return true;
1061
1062
1063
1064
1065
1066
1067
1068
1069
      return false;
    };

    std::vector<SyncPattern> sync_patterns;
    // producer_release consumer_acquire,
    // inject before the first consumer stmt for each producer
    for (int i = 0; i < n; i++) {
      for (int j = i + 1; j < n; j++) {
        if (is_producer[i] != is_producer[j] &&
1070
1071
            (intersect_fn(writes[i], reads[j]) ||
             intersect_fn(reads[i], writes[j]))) {
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
          sync_patterns.push_back({i, j});
          break;
        }
      }
    }

    // consumer_release producer_acquire
    // valid when is_loop is true
    // inject before the earliest producer stmt for each consumer
    bool in_loop = !is_zero(parity_);
    if (in_loop) {
      for (int i = 0; i < n; i++) {
        for (int j = 0; j < i; j++) {
          if (is_producer[i] != is_producer[j] &&
1086
1087
              (intersect_fn(writes[i], reads[j]) ||
               intersect_fn(reads[i], writes[j]))) {
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
            sync_patterns.push_back({i, j});
            break;
          }
        }
      }
    }

    return sync_patterns;
  }

1098
1099
1100
  static std::vector<SyncPattern>
  RemoveUnusedSyncPatterns(const std::vector<SyncPattern> &sync_patterns,
                           const std::vector<bool> &is_producer) {
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
    /*
      Simplify multiple release-acquire pairs into one
      ------------------
        Produce(A)
        Produce(B)
        Consume(A, B)
      ------------------
      [(0, 2), (1, 2), (2, 0)] -> [(1, 2), (2, 0)]

      Or
      ------------------
        Produce(A, B)
        Consume(A)
        Consume(B)
      ------------------
      [(0, 1), (1, 0), (2, 0)] -> [(0, 1), (2, 0)]
    */
    int M = sync_patterns.size();
    std::vector<bool> removed(M, false);
    for (int i = 0; i < M; i++) {
      for (int j = 0; j < M; j++) {
        if (is_producer[sync_patterns[i].acquire_idx] ==
                is_producer[sync_patterns[j].acquire_idx] &&
            sync_patterns[i].acquire_idx >= sync_patterns[j].acquire_idx &&
            sync_patterns[i].release_idx < sync_patterns[j].release_idx)
          removed[i] = true;
      }
    }

    std::vector<SyncPattern> sync_pattern_cleaned;
    sync_pattern_cleaned.reserve(M);
    for (int i = 0; i < M; i++)
1133
1134
      if (!removed[i])
        sync_pattern_cleaned.push_back(sync_patterns[i]);
1135
1136
1137
1138

    return sync_pattern_cleaned;
  }

1139
  /**
   * @brief Compute the synchronization map for a statement sequence.
   *
   * Steps:
   *  1. Flag each statement as producer when marker_ reports Role::kProducer.
   *  2. Build the base pairs (CreateBaseSyncPairs) and prune them
   *     (RemoveUnusedSyncPatterns).
   *  3. Register every surviving pattern on its acquiring and releasing
   *     statement, with release_after = true.
   *  4. Walk the statements from last to first. A statement with no release
   *     pattern of its own inherits the patterns gathered so far from later
   *     statements of the same role, with release_after = false. Otherwise
   *     its own patterns are added to that role's gathered list.
   *
   * @param seq_stmt Statements of the sequence.
   * @return The populated SyncPatternMap.
   */
  SyncPatternMap ExtractSyncPattern(const Array<Stmt> &seq_stmt) {
    const size_t num_stmts = seq_stmt.size();

    std::vector<bool> is_producer;
    is_producer.reserve(num_stmts);
    for (const auto &stmt : seq_stmt) {
      const bool producer = marker_.GetRole(stmt) == Role::kProducer;
      is_producer.push_back(producer);
    }

    auto sync_patterns = RemoveUnusedSyncPatterns(
        CreateBaseSyncPairs(seq_stmt, is_producer), is_producer);

    SyncPatternMap map;
    map.resize(num_stmts);
    map.patterns = sync_patterns;

    for (size_t pattern_id = 0; pattern_id < sync_patterns.size();
         pattern_id++) {
      const auto &pattern = sync_patterns[pattern_id];
      map.acquire[pattern.acquire_idx].push_back(pattern_id);
      map.release[pattern.release_idx].push_back(pattern_id);
      map.release_after[pattern.release_idx].push_back(true);
    }

    // Patterns seen so far (scanning backwards), kept separately per role.
    std::vector<int> cur_consumer_barrier, cur_producer_barrier;
    for (int i = static_cast<int>(num_stmts) - 1; i >= 0; i--) {
      std::vector<int> &pending =
          is_producer[i] ? cur_producer_barrier : cur_consumer_barrier;
      auto &own_release = map.release[i];
      if (!own_release.empty()) {
        for (auto pattern_idx : own_release) {
          pending.push_back(pattern_idx);
        }
        continue;
      }
      for (auto pattern_idx : pending) {
        own_release.push_back(pattern_idx);
        map.release_after[i].push_back(false);
      }
    }
    return map;
  }

  const bool is_emitting_producer_;
  Map<Var, Buffer> buffer_data_to_buffer_;
  std::unordered_set<int> released_barrier_;
1201
  const WarpSpecializedRoleMarker &marker_;
1202
1203
1204
1205
1206

  int num_barriers_ = 0;
  PrimExpr parity_ = 0;
  PrimExpr stage_ = 0;
  int num_stages_ = 1;
1207
  std::vector<LoopInfo> loop_stack_;
1208
  Var thread_var_;
1209
  bool mbarrier_only_ = false;
1210
1211
  PipelineInfo pipeline_info_;
  friend class WarpSpecializedRewriter;
1212
1213
  bool only_has_wgmma_ = false;
  bool has_simt_copy_ = false;
1214
1215
1216
};

class WarpSpecializedRewriter : public StmtExprMutator {
1217
public:
1218
1219
1220
1221
1222
1223
  WarpSpecializedRewriter(bool disable_warp_specialized,
                          bool disable_shuffle_elect)
      : disable_warp_specialized_(disable_warp_specialized),
        disable_shuffle_elect_(disable_shuffle_elect) {}
  static PrimFunc Substitute(PrimFunc f, bool disable_warp_specialized,
                             bool disable_shuffle_elect) {
1224
1225
1226
    // Check if function only uses threadIdx.x before proceeding
    if (!ThreadTagChecker::HasOnlyThreadIdxX(f)) {
      LOG(WARNING) << "WarpSpecialize will be disabled because the program "
1227
                      "uses thread tags other than threadIdx.x."
1228
1229
1230
1231
1232
1233
                   << "If you want to use warp specialization, please refactor "
                      "your program to use threadIdx.x only";
      // Return original function unchanged if other thread tags are found
      return f;
    }

1234
1235
    auto T = WarpSpecializedRewriter(disable_warp_specialized,
                                     disable_shuffle_elect);
1236
    T.buffer_lca_ = DetectBufferAccessLCA(f);
1237
1238
    for (auto [buffer, _] : T.buffer_lca_)
      T.buffer_data_to_buffer_.Set(buffer->data, buffer);
1239
1240
1241
1242
    f.CopyOnWrite()->body = T(f->body);
    return f;
  }

1243
1244
private:
  Stmt VisitStmt_(const AttrStmtNode *op) final {
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
    if (op->attr_key == tir::attr::thread_extent &&
        Downcast<IterVar>(op->node)->thread_tag == "threadIdx.x") {
      thread_iv_ = Downcast<IterVar>(op->node);
      need_update_thread_extent_ = false;
      AttrStmt attr_stmt = Downcast<AttrStmt>(StmtExprMutator::VisitStmt_(op));
      if (need_update_thread_extent_) {
        thread_iv_.CopyOnWrite()->dom = {0, updated_thread_extent_.value()};
        attr_stmt.CopyOnWrite()->node = thread_iv_;
        attr_stmt.CopyOnWrite()->value = updated_thread_extent_.value();
      }
      thread_iv_ = {};
      return attr_stmt;
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

1262
1263
1264
1265
  // If users define a thread binding, we will replace the thread binding with
  // threadIdx.x We require the thread binding is threadIdx.x, and the extent is
  // the same as the thread extent
  Stmt VisitStmt_(const ForNode *op) final {
1266
1267
1268
1269
1270
1271
1272
    ICHECK(thread_iv_.defined());
    For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
    if (for_node->kind == ForKind::kThreadBinding) {
      ICHECK(for_node->thread_binding.defined());
      String thread_tag = for_node->thread_binding.value()->thread_tag;
      ICHECK(thread_tag == "threadIdx.x") << "Only support threadIdx.x";
      Var thread_iv = Downcast<Var>(for_node->loop_var);
1273
      Stmt new_body =
1274
          ThreadIdxRewriter::Rewrite(for_node->body, thread_iv, thread_iv_, 0);
1275
1276
1277
1278
1279
      return new_body;
    }
    return for_node;
  }

1280
  /**
1281
1282
   * @brief Rewrite a BlockRealize for warp specialization, inserting barriers
   * and emitting producer/consumer bodies.
1283
1284
1285
1286
1287
   *
   * This visitor handles BlockRealize nodes when a thread IterVar (thread_iv_)
   * is defined and warp-specialization is applicable. It:
   * - Determines producer/consumer roles via WarpSpecializedRoleMarker and
   *   returns the original block if no producer is detected.
1288
1289
   * - If warp specialization is disabled, emits only mbarrier initialization
   * and the mbarrier-only transformed body.
1290
1291
1292
   * - Otherwise, detects WgMMA usage for the block body and constructs separate
   *   WSCodeEmitter instances for producer and consumer paths (propagating the
   *   WgMMA flag to the consumer emitter).
1293
1294
1295
   * - Generates producer/consumer code, applies register hint calls
   * (set_max_nreg) when available, and rewrites thread indices with
   * ThreadIdxRewriter to partition threads between producer and consumer roles.
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
   * - Computes and initializes a list of mbarrier handles with per-barrier
   *   arrive thread counts (taking SIMT-copy and WgMMA cases into account).
   * - Wraps the transformed body in an IfThenElse that dispatches producer vs
   *   consumer based on thread index, and annotates the region with the
   *   "kWarpSpecializationScope" attribute that contains producer/consumer
   *   thread extents.
   *
   * Side effects:
   * - May update member state: only_has_wgmma_, updated_thread_extent_,
   *   need_update_thread_extent_.
   * - May abort via ICHECK if invariants (e.g., matching barrier counts) are
   *   violated.
   *
   * @return The possibly rewritten BlockRealize statement (original when no
   *         warp-specialization is applied or thread_iv_ is undefined).
   */
1312
1313
1314
  Stmt VisitStmt_(const BlockRealizeNode *op) final {
    BlockRealize block_realize =
        Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
1315
1316
1317
1318
1319
1320
    if (!thread_iv_.defined()) {
      return block_realize;
    }

    Block block = block_realize->block;
    WarpSpecializedRoleMarker marker(buffer_data_to_buffer_);
1321
    marker.Prepare(block);
1322
1323
1324
1325
1326
1327
    marker(block);
    if (!marker.HasProducer()) {
      // Cannot detect any producer here, directly return.
      return block_realize;
    }

1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
    if (disable_warp_specialized_) {
      WSCodeEmitter mbarrier_emitter(true, thread_iv_, buffer_data_to_buffer_,
                                     marker, true);
      auto code = mbarrier_emitter(block->body);
      int num_barriers = mbarrier_emitter.num_barriers_;
      Array<PrimExpr> barrier_num_threads;
      barrier_num_threads.reserve(num_barriers);
      PrimExpr arrive_thread_count = thread_iv_->dom->extent;
      for (int i = 0; i < num_barriers; i++) {
        barrier_num_threads.push_back(arrive_thread_count);
      }
      Stmt init_barrier = Evaluate(Call(
1340
          DataType::Handle(), create_list_of_mbarrier(), barrier_num_threads));
1341
1342
1343
1344
      block.CopyOnWrite()->body = SeqStmt({init_barrier, code});
      block_realize.CopyOnWrite()->block = block;
      return block_realize;
    }
1345
    only_has_wgmma_ = WgMMACollector::HasWgMMA(block->body);
1346
    WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker);
1347
1348
    WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker,
                           false, only_has_wgmma_);
1349
1350
1351
1352
1353
    Stmt producer_code = producer(block->body);
    Stmt consumer_code = consumer(block->body);
    PrimExpr consumer_thread_extent = thread_iv_->dom->extent;
    PrimExpr producer_thread_extent = thread_iv_->dom->extent;
    // Need one warp-group for bulk-copy only case
1354
1355
    if (!marker.HasSimtCopy())
      producer_thread_extent = 128;
1356
1357

    updated_thread_extent_ = consumer_thread_extent + producer_thread_extent;
1358
1359
1360
1361
1362
1363
1364
1365

    producer_code = ThreadIdxRewriter::Rewrite(
        producer_code, thread_iv_->var,
        thread_iv_->var - consumer_thread_extent, producer_thread_extent,
        !disable_shuffle_elect_);
    consumer_code = ThreadIdxRewriter::Rewrite(
        consumer_code, thread_iv_->var, thread_iv_->var, consumer_thread_extent,
        !disable_shuffle_elect_);
1366
1367
1368
1369
1370
1371
1372
1373
    need_update_thread_extent_ = true;

    ICHECK(producer.num_barriers_ == consumer.num_barriers_)
        << producer.num_barriers_ << " " << consumer.num_barriers_;
    int num_barriers = consumer.num_barriers_;
    Array<PrimExpr> barrier_num_threads;
    barrier_num_threads.reserve(num_barriers);
    for (int i = 0; i < num_barriers; i++) {
1374
1375
1376
      PrimExpr arrive_thread_count =
          producer.released_barrier_.count(i)
              ? (producer.hasSimtCopy() ? producer_thread_extent : 1)
1377
1378
              : (only_has_wgmma_ ? FloorDiv(consumer_thread_extent, 128)
                                 : consumer_thread_extent);
1379
1380
1381
      barrier_num_threads.push_back(arrive_thread_count);
    }

1382
    Stmt init_barrier = Evaluate(Call(
1383
        DataType::Handle(), create_list_of_mbarrier(), barrier_num_threads));
1384
1385
    Stmt body = IfThenElse(GE(thread_iv_->var, consumer_thread_extent),
                           producer_code, consumer_code);
1386
    // Add an attr here to handle the partial thread count in ThreadSync pass.
1387
1388
    Array<IntImm> ws_partition = {Downcast<IntImm>(producer_thread_extent),
                                  Downcast<IntImm>(consumer_thread_extent)};
1389
    body = AttrStmt(ws_partition, attr::kWarpSpecializationScope, 0, body);
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403

    block.CopyOnWrite()->body = SeqStmt({init_barrier, body});
    block_realize.CopyOnWrite()->block = block;
    return block_realize;
  }

  WarpSpecializedRewriter() = default;

  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Optional<Stmt>> buffer_lca_;
  Map<Buffer, Buffer> buffer_remap_;
  IterVar thread_iv_;
  Optional<PrimExpr> updated_thread_extent_;
  bool need_update_thread_extent_ = false;
1404
  bool disable_warp_specialized_ = false;
1405
  bool disable_shuffle_elect_ = false;
1406
  bool only_has_wgmma_ = false;
1407
1408
};

1409
1410
class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
public:
1411
  // return true means this aws will be disabled
1412
  static bool Detect(const Stmt &stmt, bool skip_thread_partition = false) {
1413
1414
    WarpSpecializedDetector detector;
    detector.VisitStmt(stmt);
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
    if (detector.has_warp_specialization_) {
      LOG(WARNING) << "Auto warp specialization will be disabled because warp "
                      "specialization is manually enabled";
      return true;
    }
    if (detector.has_tma_op_ && detector.has_mbarrier_op_) {
      LOG(WARNING) << "Auto warp specialization will be disabled because TMA "
                      "and mbarrier are both present";
      return true;
    }
    return false;
1426
1427
1428
1429
1430
  }

  WarpSpecializedDetector() {
    has_tma_op_ = false;
    has_mbarrier_op_ = false;
1431
    has_warp_specialization_ = false;
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
  }

private:
  void VisitStmt_(const EvaluateNode *op) final {
    if (const CallNode *call = op->value.as<CallNode>()) {
      if (call->op.same_as(create_list_of_mbarrier()) ||
          call->op.same_as(mbarrier_wait_parity()) ||
          call->op.same_as(builtin::ptx_arrive_barrier()) ||
          call->op.same_as(builtin::ptx_cp_async_barrier())) {
        has_mbarrier_op_ = true;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  void VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
        op->op.same_as(set_max_nreg())) {
      has_tma_op_ = true;
    }
    IRVisitorWithAnalyzer::VisitExpr_(op);
  }

1455
  void VisitStmt_(const AttrStmtNode *op) final {
1456
1457
1458
1459
    if (op->attr_key == "warp_specialize" &&
        op->value.as<IntImmNode>()->value == 1) {
      has_warp_specialization_ = true;
    }
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        ICHECK(iv->dom->extent.as<IntImmNode>());
        thread_var_ = iv;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

1470
  bool has_tma_op_{false};
1471
  IterVar thread_var_;
1472
  bool has_mbarrier_op_{false};
1473
  bool has_warp_specialization_{false};
1474
1475
};

1476
1477
1478
using namespace tir::transform;

/**
 * @brief Create the "tl.WarpSpecialized" PrimFunc pass.
 *
 * For each function the pass reads the kDisableWarpSpecialized and
 * kDisableShuffleElect config flags (both default to false). If
 * WarpSpecializedDetector::Detect reports that auto warp specialization must
 * be skipped, the function is returned untouched; otherwise it is rewritten
 * by WarpSpecializedRewriter::Substitute.
 */
tvm::transform::Pass WarpSpecialized() {
  auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
    const bool no_warp_specialization =
        ctx->GetConfig<Bool>(kDisableWarpSpecialized, Bool(false)).value();
    const bool no_shuffle_elect =
        ctx->GetConfig<Bool>(kDisableShuffleElect, Bool(false)).value();

    // Detect() returns true when the rewrite has to be skipped.
    const bool skip_rewrite = WarpSpecializedDetector::Detect(f->body);
    if (skip_rewrite) {
      return f;
    }
    return WarpSpecializedRewriter::Substitute(f, no_warp_specialization,
                                               no_shuffle_elect);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
}

1495
1496
1497
1498
// Registers the WarpSpecialized pass factory in the FFI reflection registry
// under the global name "tl.transform.WarpSpecialized".
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.WarpSpecialized", WarpSpecialized);
});
1499

1500
1501
} // namespace tl
} // namespace tvm