/*!
 * \file warp_specialized_rewriter.cc
 * \brief Warp-specialized software pipeline rewriting for CUDA GPUs (sm90+).
 */

#include <algorithm>
#include <iostream>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../op/builtin.h"
#include "./common/collector.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "runtime/thread_storage_scope.h"
#include "tir/analysis/var_use_def_analysis.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;
24
using namespace runtime;
25
using arith::IRVisitorWithAnalyzer;
26

27
28
29
30
31
32
/*!
 * \brief Bookkeeping for one loop on the loop stack: its variable, extent and
 *        min bound.
 *
 * Instances are aggregate-initialised positionally as
 * LoopInfo{loop_var, extent, min}, so the declaration order below matters.
 */
struct LoopInfo {
  Var loop_var;
  PrimExpr extent, min;
};

/*!
 * \brief Which side of the warp-specialized split a statement belongs to.
 */
enum class Role {
  kConsumer, // consumer-side work (normally dropped on the producer path)
  kProducer, // producer-side work (dropped on the consumer path)
  kBoth,     // mixed or control statements, kept and recursed into on both paths
};

35
class ProducerBufferDetector : public StmtExprVisitor {
36
public:
37
38
39
40
41
  ProducerBufferDetector(
      std::unordered_set<const BufferNode *> cur_producer_buffers)
      : cur_producer_buffers_(cur_producer_buffers) {}

  void clear() { has_producer_buffer_ = false; }
42
43

  void VisitExpr_(const CallNode *call) final {
44
    if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
45
      has_producer_buffer_ = true;
46
    }
47
    StmtExprVisitor::VisitExpr_(call);
48
49
  }

50
51
52
53
54
55
56
57
58
  void VisitExpr_(const BufferLoadNode *op) final {
    if (cur_producer_buffers_.count(op->buffer.get())) {
      has_producer_buffer_ = true;
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  bool has_producer_buffer_ = false;
  std::unordered_set<const BufferNode *> cur_producer_buffers_;
59
60
61
62
63
};

class ProducerUsedBufferFinder : public StmtExprVisitor {
public:
  auto FindProducerusedBuffer(Stmt stmt) {
64
65
66
67
68
69
70
71
72
73
    producer_buffers_.clear();
    std::unordered_set<const BufferNode *> last_producer_buffers_;
    for (;;) {
      VisitStmt(stmt);
      if (producer_buffers_ == last_producer_buffers_) {
        break;
      }
      last_producer_buffers_ = producer_buffers_;
    }
    return producer_buffers_;
74
75
76
77
78
79
80
  }

  void InsertBuffer(const PrimExpr &expr) {
    // Find the buffer that is used in the condition
    VarUseDefAnalyzer usage(Array<Var>{});
    usage(expr);
    for (const auto &buffer : usage.buffer_use_count_) {
81
      producer_buffers_.insert(buffer.first);
82
83
84
85
    }
  }

  void VisitStmt_(const IfThenElseNode *op) final {
86
87
    ProducerBufferDetector producer_buffer_detector(producer_buffers_);
    producer_buffer_detector(op->then_case);
88
    if (op->else_case.defined()) {
89
      producer_buffer_detector(op->else_case.value());
90
    }
91
    if (producer_buffer_detector.has_producer_buffer_) {
92
93
94
95
96
97
      InsertBuffer(op->condition);
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const ForNode *op) final {
98
99
100
    ProducerBufferDetector producer_buffer_detector(producer_buffers_);
    producer_buffer_detector(op->body);
    if (producer_buffer_detector.has_producer_buffer_) {
101
102
103
104
105
106
      InsertBuffer(op->min);
      InsertBuffer(op->extent);
    }
    StmtExprVisitor::VisitStmt_(op);
  }

107
108
109
110
111
112
113
  void VisitStmt_(const BufferStoreNode *op) final {
    if (producer_buffers_.count(op->buffer.get())) {
      InsertBuffer(op->value);
    }
    StmtExprVisitor::VisitStmt_(op);
  }

114
115
116
117
  void VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
      for (auto arg : op->args) {
        if (auto buffer_load = arg.as<BufferLoadNode>()) {
118
          producer_buffers_.insert(buffer_load->buffer.get());
119
120
121
122
123
        }
      }
    }
  }

124
private:
125
  std::unordered_set<const BufferNode *> producer_buffers_;
126
127
};

128
class WarpSpecializedRoleMarker : public StmtVisitor {
129
public:
130
131
132
  WarpSpecializedRoleMarker(Map<Var, Buffer> buffer_data_to_buffer)
      : buffer_data_to_buffer_(buffer_data_to_buffer) {}

133
134
  void Prepare(const Stmt &stmt) {
    ProducerUsedBufferFinder finder;
135
    producer_buffers_ = finder.FindProducerusedBuffer(stmt);
136
137
  }

138
  Role GetRole(const StmtNode *stmt) const {
139
140
141
142
143
    auto it = map_.find(stmt);
    ICHECK(it != map_.end());
    return it->second;
  }

144
  Role GetRole(const Stmt &stmt) const { return GetRole(stmt.get()); }
145

146
  void VisitStmt_(const EvaluateNode *op) final {
147
148
    Role role = Role::kConsumer;
    if (auto call = op->value.as<CallNode>()) {
149
      if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
150
151
152
        role = Role::kProducer;
        has_bulk_copy_ = true;
      }
153
154
155
      if (call->op.same_as(loop_break())) {
        role = Role::kBoth;
      }
156
157
158
159
    }
    SetRole(op, role);
  }

160
  void VisitStmt_(const BufferStoreNode *op) final {
161
162
    auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer->data));
    bool is_shared_store = scope.rank == StorageRank::kShared;
163
    if (producer_buffers_.count(op->buffer.get())) {
164
165
166
      SetRole(op, Role::kBoth);
      return;
    }
167
168
169
170
171
172
173
174
175
176
177
    if (!is_shared_store) {
      SetRole(op, Role::kConsumer);
      return;
    }

    // Check reads from global
    Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
                /*body*/ GetRef<Stmt>(op));
    auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
    auto reads = access[0];
    Role role = Role::kProducer;
178
179
    if (reads.empty())
      role = Role::kConsumer;
180
181
182
183
184
185
    for (auto read : reads) {
      if (read->buffer.scope() != "global") {
        role = Role::kConsumer;
        break;
      }
    }
186
187
    if (role == Role::kProducer)
      has_simt_copy_ = true;
188
189
190
    SetRole(op, role);
  }

191
  void VisitStmt_(const SeqStmtNode *op) final {
192
193
194
195
196
197
198
199
200
201
202
    StmtVisitor::VisitStmt_(op);
    auto role = GetRole(op->seq[0]);
    for (auto stmt : op->seq) {
      if (role != GetRole(stmt)) {
        role = Role::kBoth;
        break;
      }
    }
    SetRole(op, role);
  }

203
  void VisitStmt_(const IfThenElseNode *op) final {
204
205
206
207
    StmtVisitor::VisitStmt_(op);
    auto role = GetRole(op->then_case);
    if (op->else_case.defined()) {
      auto role_else = GetRole(op->else_case.value());
208
209
      if (role != role_else)
        role = Role::kBoth;
210
211
212
213
    }
    SetRole(op, role);
  }

214
  void VisitStmt_(const BlockRealizeNode *op) final {
215
216
217
218
    StmtVisitor::VisitStmt_(op);
    SetRole(op, GetRole(op->block));
  }

219
220
221
222
223
224
  void VisitStmt_(const AllocateNode *op) final {
    StmtVisitor::VisitStmt_(op);
    Role role = Role::kConsumer;
    SetRole(op, role);
  }

225
  template <class NodeType> void HandleBodyStmt(const NodeType *op) {
226
227
228
229
    StmtVisitor::VisitStmt_(op);
    SetRole(op, GetRole(op->body));
  }

230
  void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); }
231
  void VisitStmt_(const WhileNode *op) final { HandleBodyStmt(op); }
232
233
234
235
  void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); }
236
237
238
239
240

  bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; }

  bool HasSimtCopy() { return has_simt_copy_; }

241
242
private:
  void SetRole(const StmtNode *stmt, Role role) { map_[stmt] = role; }
243
  Map<Var, Buffer> buffer_data_to_buffer_;
244
  std::unordered_map<const StmtNode *, Role> map_;
245
246
  bool has_simt_copy_ = false;
  bool has_bulk_copy_ = false;
247
  std::unordered_set<const BufferNode *> producer_buffers_;
248
249
250
};

/*! \brief Build the handle expression `get_mbarrier(barrier_id)`. */
static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
  Array<PrimExpr> call_args{barrier_id};
  return Call(DataType::Handle(), get_mbarrier(), call_args);
}

254
255
256
257
258
259
260
261
262
static Stmt makeArriveBarrier(PrimExpr barrier_id, int cta_id = -1,
                              PrimExpr pred = 1) {
  Array<PrimExpr> args = {makeGetBarrier(barrier_id)};
  if (cta_id != -1) {
    args.push_back(cta_id);
    args.push_back(pred);
  }
  return Evaluate(
      Call(DataType::Handle(), builtin::ptx_arrive_barrier(), args));
263
264
265
}

/*! \brief Emit `ptx_cp_async_barrier(get_mbarrier(barrier_id))`. */
static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) {
  return Evaluate(Call(DataType::Handle(), builtin::ptx_cp_async_barrier(),
                       {makeGetBarrier(barrier_id)}));
}

/*! \brief Emit `mbarrier_wait_parity(get_mbarrier(barrier_id), parity)`. */
static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
  PrimExpr barrier_handle = makeGetBarrier(barrier_id);
  return Evaluate(Call(DataType::Handle(), mbarrier_wait_parity(),
                       {barrier_handle, parity}));
}

class ProducerTraitsCollector : public StmtExprVisitor {
278
public:
279
280
  ProducerTraitsCollector() { Clear(); }

281
  void Clear() { has_simt_copy = false; }
282
283
284
285
286

  void Collect(Stmt stmt) { VisitStmt(stmt); }

  bool HasSimtCopy() { return has_simt_copy; }

287
private:
288
289
290
291
292
293
294
295
296
297
298
299
  void VisitStmt_(const IfThenElseNode *op) final {
    bool old_in_if_cond = in_if_cond_;
    in_if_cond_ = true;
    VisitExpr(op->condition);
    in_if_cond_ = old_in_if_cond;

    VisitStmt(op->then_case);
    if (op->else_case.defined()) {
      VisitStmt(op->else_case.value());
    }
  }

300
  void VisitExpr_(const BufferLoadNode *op) final {
301
302
303
    if (!in_if_cond_) {
      has_simt_copy = true;
    }
304
305
306
307
    StmtExprVisitor::VisitExpr_(op);
  }

  bool has_simt_copy;
308
  bool in_if_cond_ = false;
309
310
311
312
};

// Rewrite the producer Stmt to use the correct barrier index
class MbarrierRewriter : public StmtExprMutator {
313
public:
314
315
316
317
318
319
  static Stmt Rewrite(Stmt stmt, PrimExpr barrier_id) {
    MbarrierRewriter rewriter;
    rewriter.producer_barrier_idx_ = barrier_id;
    return rewriter(stmt);
  }

320
321
private:
  PrimExpr VisitExpr_(const CallNode *op) final {
322
    auto call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
323
    if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
324
325
326
327
328
329
330
331
332
333
      Call access_ptr = Downcast<Call>(call->args[2]);
      ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
      call.CopyOnWrite()->args.Set(1, makeGetBarrier(producer_barrier_idx_));
    }
    return call;
  }
  PrimExpr producer_barrier_idx_;
};

class ThreadIdxRewriter : public StmtExprMutator {
334
public:
335
336
337
338
  static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced,
                      PrimExpr thread_extent, bool do_shuffle = false) {
    auto rewriter =
        ThreadIdxRewriter(thread_var, replaced, thread_extent, do_shuffle);
339
340
341
    return rewriter(stmt);
  }

342
private:
343
344
345
346
  ThreadIdxRewriter(Var thread_var, PrimExpr replaced, PrimExpr thread_extent,
                    bool do_shuffle)
      : thread_var_(thread_var), replaced_(replaced),
        thread_extent_(thread_extent), do_shuffle_(do_shuffle) {}
347

348
  PrimExpr VisitExpr_(const VarNode *var) final {
349
350
351
352
353
354
355
    if (var == thread_var_.get()) {
      return replaced_;
    } else {
      return StmtExprMutator::VisitExpr_(var);
    }
  }

356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
  Stmt VisitStmt_(const IfThenElseNode *op) final {
    auto f_uses_thread_index = [=](const tvm::tir::VarNode *parameter) {
      return parameter == thread_var_.get();
    };
    maybe_thread_opt_ = false;
    if (!op->else_case.defined() && op->condition.as<EQNode>() &&
        UsesVar(op->condition, f_uses_thread_index) &&
        !(UsesVar(op->then_case, f_uses_thread_index))) {
      auto eq_op = Downcast<EQ>(op->condition);
      if (eq_op->a.as<VarNode>() == thread_var_.get() ||
          eq_op->b.as<VarNode>() == thread_var_.get()) {
        maybe_thread_opt_ = true;
      }
      maybe_thread_opt_ = do_shuffle_ && maybe_thread_opt_;
    }
    if (maybe_thread_opt_)
      return IfThenElse(
          Call(DataType::Bool(), tl_shuffle_elect(), {thread_extent_}),
          StmtExprMutator::VisitStmt(op->then_case), std::nullopt);
    else
      return StmtExprMutator::VisitStmt_(op);
  }

379
380
  Var thread_var_;
  PrimExpr replaced_;
381
382
383
  PrimExpr thread_extent_;
  bool maybe_thread_opt_ = false;
  bool do_shuffle_;
384
385
};

386
387
388
389
390
391
Block MakeGroupBlock(const Stmt &stmt,
                     const Map<String, ObjectRef> &annotations) {
  Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
              /*body*/ stmt,
              /*init=*/{}, /*alloc_buffers=*/{}, /*match_buffers=*/{},
              /*annotations=*/annotations);
392
393
394
395
396
397
398
399
400
401
402
  return block;
}

/*!
 * \brief One pipeline operation: a run of consecutive loop-body statements
 *        sharing a pipeline order and stage.
 *
 * Integer fields default to 0 so a default-constructed OpInfo never holds
 * indeterminate values.
 */
struct OpInfo {
  int group_size = 0;     // number of statements in `group`
  int order = 0;          // value emitted into "software_pipeline_order"
  int stage = 0;          // value emitted into "software_pipeline_stage"
  std::vector<int> group; // statement indices belonging to this op
};
struct PipelineInfo {
  std::vector<OpInfo> op_infos;

  PipelineInfo() = default;
403
404
  PipelineInfo(Array<Array<Integer>> group_info, Array<Integer> order_info,
               Array<Integer> stage_info) {
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
    int n = static_cast<int>(group_info.size());
    ICHECK(n == static_cast<int>(order_info.size()));
    ICHECK(n == static_cast<int>(stage_info.size()));
    // int cur_id = 0;
    for (int i = 0; i < n; i++) {
      OpInfo op_info;
      op_info.group_size = group_info[i].size();
      for (int j = 0; j < op_info.group_size; j++) {
        op_info.group.push_back(group_info[i][j].as<IntImmNode>()->value);
      }
      op_info.order = order_info[i].as<IntImmNode>()->value;
      op_info.stage = stage_info[i].as<IntImmNode>()->value;
      op_infos.push_back(op_info);
    }
  }

421
  PipelineInfo(const PipelineInfo &other) {
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
    for (auto op_info : other.op_infos) {
      op_infos.push_back(op_info);
    }
  }

  std::pair<int, int> FindStmt(int stmt_idx) {
    for (size_t i = 0; i < op_infos.size(); i++) {
      for (size_t j = 0; j < op_infos[i].group.size(); j++) {
        if (op_infos[i].group[j] == stmt_idx) {
          return std::make_pair(i, j);
        }
      }
    }
    return std::make_pair(-1, -1);
  }

  void UpdateOrder(int order) {
    for (int i = 0; i < static_cast<int>(op_infos.size()); i++) {
      if (op_infos[i].order >= order && op_infos[i].order > 0) {
        op_infos[i].order++;
      }
    }
  }

  int SplitOp(int stmt_idx) {
    auto pair = FindStmt(stmt_idx);
    int op_idx = pair.first;
    int inner_idx = pair.second;
    ICHECK(op_idx != -1);
    ICHECK(inner_idx != -1);
    OpInfo half0;
    OpInfo half1;
    // The order to do sync
    int sync_order = op_infos[op_idx].order + 1;
    UpdateOrder(sync_order);

    half0.group_size = inner_idx + 1;
    half0.order = op_infos[op_idx].order;
    half0.stage = op_infos[op_idx].stage;
    for (int i = 0; i <= inner_idx; i++) {
      half0.group.push_back(op_infos[op_idx].group[i]);
    }
    half1.group_size = op_infos[op_idx].group_size - inner_idx - 1;
    half1.order = op_infos[op_idx].order + 2;
    half1.stage = op_infos[op_idx].stage;
    for (int i = inner_idx + 1; i < op_infos[op_idx].group_size; i++) {
      half1.group.push_back(op_infos[op_idx].group[i]);
    }
    op_infos.erase(op_infos.begin() + op_idx);
    if (half0.group_size > 0) {
      op_infos.insert(op_infos.begin() + op_idx, half0);
    }
    if (half1.group_size > 0) {
      UpdateOrder(half1.order);
      op_infos.insert(op_infos.begin() + op_idx + 1, half1);
    }
    return sync_order;
  }

  void PrintPipelineInfo() {
    std::cout << "Print op_infos:" << std::endl;
    for (size_t i = 0; i < op_infos.size(); i++) {
484
485
      std::cout << i << " " << op_infos[i].group_size << " "
                << op_infos[i].order << " " << op_infos[i].stage << std::endl;
486
487
488
489
490
491
    }
    std::cout << "End of print" << std::endl;
  }
};

class GroupOpRewriter : public StmtExprMutator {
492
public:
493
494
  GroupOpRewriter(PipelineInfo pipeline_info) : pipeline_info_(pipeline_info) {}

495
496
private:
  Stmt VisitStmt_(const ForNode *op) final {
497
498
499
500
501
502
503
504
505
    Map<String, ObjectRef> annotations;
    annotations.Set(String("stmt_group"), Integer(1));
    auto original_node = (op->body).as<SeqStmtNode>();
    if (!original_node) {
      return GetRef<For>(op);
    }
    Array<Stmt> new_body;
    int cur_id = 0;
    for (int i = 0; i < static_cast<int>(pipeline_info_.op_infos.size()); i++) {
506
507
      if (pipeline_info_.op_infos[i].group_size == 0)
        continue;
508
      Array<Stmt> block_stmt;
509
510
      for (int j = 0;
           j < static_cast<int>(pipeline_info_.op_infos[i].group_size); j++) {
511
        // ICHECK(group_info_[i][j].as<IntImmNode>());
512
513
        // int index =
        // static_cast<int>(group_info_[i][j].as<IntImmNode>()->value);
514
515
516
517
518
519
        ICHECK(original_node->seq[cur_id].as<BlockNode>());
        auto block = original_node->seq[cur_id].as<BlockNode>();
        // TODO: handle nested seqstmt
        block_stmt.push_back(block->body);
        cur_id++;
      }
520
521
522
523
      new_body.push_back(MakeGroupBlock(block_stmt.size() == 1
                                            ? block_stmt[0]
                                            : SeqStmt(std::move(block_stmt)),
                                        annotations));
524
525
526
527
528
529
530
    }
    Array<Integer> order_anno;
    Array<Integer> stage_anno;
    for (auto op_info : pipeline_info_.op_infos) {
      order_anno.push_back(Integer(op_info.order));
      stage_anno.push_back(Integer(op_info.stage));
    }
531
    Map<String, Any> for_annotations = op->annotations;
532
533
534
    for_annotations.erase("tl_pipeline_group");
    for_annotations.Set("software_pipeline_order", order_anno);
    for_annotations.Set("software_pipeline_stage", stage_anno);
535
536
537
538
    For new_for =
        For(op->loop_var, op->min, op->extent, op->kind,
            new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body)),
            op->thread_binding, for_annotations);
539
540
541
542
543
    return new_for;
  }

  PipelineInfo pipeline_info_;
};
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578

class WgMMACollector : public StmtExprVisitor {
public:
  WgMMACollector() = default;

  void VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(tl_gemm()) || op->op.same_as(tl_gemm_sp())) {
      auto op_name = std::string(op->args[0].as<StringImmNode>()->value);
      if (has_wgmma_) {
        has_wgmma_ =
            op_name.find("false") == std::string::npos && !in_if_scope_;
      }
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const IfThenElseNode *op) final {
    in_if_scope_ = true;
    StmtExprVisitor::VisitStmt(op->then_case);
    if (op->else_case.defined()) {
      StmtExprVisitor::VisitStmt(op->else_case.value());
    }
    in_if_scope_ = false;
  }

  static bool HasWgMMA(Stmt stmt) {
    auto collector = WgMMACollector();
    collector(stmt);
    return collector.has_wgmma_;
  }

  bool has_wgmma_{true};
  bool in_if_scope_{false};
};

579
class WSCodeEmitter : public StmtMutator {
580
public:
581
  /**
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
   * @brief Construct a warp-specialized code emitter configured for producer or
   * consumer emission.
   *
   * Initializes a WSCodeEmitter that will emit barrier-aware, role-filtered
   * code for a single warp-specialized block. The emitter is configured with
   * the loop/thread iteration variable, buffer mapping, role marker used to
   * classify statements, and two flags that control emission behavior:
   *
   * - `mbarrier_only`: when true, emission is restricted to barrier-related
   * operations only.
   * - `only_has_wgmma`: when true, the emitter will account for the presence of
   * WgMMA (workgroup MMA) operations when computing barrier/thread gating
   * behavior.
   *
   * @param is_emitting_producer True to emit producer-side groups; false to
   * emit consumer-side groups.
   * @param thread_iv IterVar representing the thread iteration variable
   * (threadIdx.*) whose Var is used for thread-index rewrites and gating.
   * @param buffer_data_to_buffer Map from buffer data Var to the corresponding
   * Buffer (used to resolve buffer references during emission).
   * @param marker Role marker that classifies statements as
   * producer/consumer/both; used to filter which statements are emitted on this
   * path.
   * @param mbarrier_only If true, restrict emission to mbarrier-related
   * statements and helpers.
   * @param only_has_wgmma If true, adjust emission and barrier-thread-count
   * logic for blocks that contain WgMMA operations.
   */
  WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
611
                Map<Var, Buffer> buffer_data_to_buffer,
612
                const WarpSpecializedRoleMarker &marker,
613
                bool mbarrier_only = false, bool only_has_wgmma = false)
614
      : is_emitting_producer_(is_emitting_producer),
615
        buffer_data_to_buffer_(buffer_data_to_buffer), marker_(marker),
616
617
        thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only),
        only_has_wgmma_(only_has_wgmma) {}
618

619
  /**
   * @brief Whether a SIMT (thread-level) copy was emitted on the producer path.
   *
   * Set while emitting a producer/consumer-mixed SeqStmt on the producer path,
   * when a rewritten producer statement contains a BufferLoad outside of an
   * if-condition (as detected by ProducerTraitsCollector). In that case a
   * ptx_cp_async_barrier is emitted after the statement.
   *
   * @return true if such a copy was emitted; false otherwise.
   */
  bool hasSimtCopy() const { return has_simt_copy_; }
private:
  /**
   * @brief Keep, drop, or recurse into `op` according to its role.
   *
   * - Role::kBoth, or any non-producer statement when mbarrier_only_ is set:
   *   recurse with the default StmtMutator visitor.
   * - A statement belonging to the side being emitted: returned unchanged.
   * - A statement belonging to the other side: replaced by Evaluate(0).
   */
  template <typename NodeType> Stmt FilterByRole(const NodeType *op) {
    const Role role = marker_.GetRole(op);
    const bool recurse =
        role == Role::kBoth || (mbarrier_only_ && role != Role::kProducer);
    if (recurse) {
      return StmtMutator::VisitStmt_(op);
    }
    const bool is_producer_stmt = (role == Role::kProducer);
    if (is_producer_stmt != is_emitting_producer_) {
      return Evaluate(0);
    }
    return GetRef<Stmt>(op);
  }
  /**
647
648
   * @brief Visit and transform a SeqStmt node, emitting grouped blocks with
   * barrier synchronization according to producer/consumer roles.
649
650
   *
   * This method examines the sequence to determine whether producer-side
651
652
   * synchronization is required (based on marker_ roles). If no producer sync
   * is needed it delegates to FilterByRole. Otherwise it:
653
654
655
656
   * - Recursively visits and transforms each child statement.
   * - Extracts an acquire/release sync pattern for the sequence via
   *   ExtractSyncPattern.
   * - For producer emission (is_emitting_producer_ == true):
657
658
   *   - Skips consumer-only statements unless marker_ marks a statement as
   * Both, in which case the statement is emitted as its own group.
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
   *   - For each statement, inserts parity waits for acquire patterns, rewrites
   *     release statements with MbarrierRewriter using a computed barrier id,
   *     collects SimT-copy presence (setting has_simt_copy_ and inserting
   *     cp.async barriers when found), optionally emits arrive barriers for
   *     release-after events, and emits each resulting set of statements as a
   *     group block annotated with "stmt_group".
   * - For consumer emission (is_emitting_producer_ == false):
   *   - Skips producer-only statements.
   *   - Inserts parity waits for acquire patterns, appends the transformed
   *     statement, and emits arrive barriers for release-after events. When
   *     only_has_wgmma_ is set, the arrive barrier uses a per-thread predicate
   *     (FloorMod(thread_var_,128)==0) with CTA=0; otherwise a full arrive is
   *     emitted.
   *   - Recomputes pipeline_info_ to drop producer-only ops.
   *
   * Side effects / state updates:
   * - Increments num_barriers_ by (number of extracted patterns * num_stages_).
   * - May set has_simt_copy_ when a SimT copy is detected in producer rewrites.
   * - Inserts barrier ids into released_barrier_ for release-after events.
   * - Updates pipeline_info_ for the consumer path to remove producer ops.
   *
   * The resulting statements are emitted as grouped blocks (via MakeGroupBlock)
   * with the annotation "stmt_group" and returned as either a single Stmt (if
   * there's only one group) or a SeqStmt containing the grouped blocks.
   *
   * @return Stmt The transformed statement (either a single group block or a
   * SeqStmt of group blocks).
   */
687
  Stmt VisitStmt_(const SeqStmtNode *op) final {
688

689
690
691
692
693
694
695
    bool has_producer = false;
    for (auto stmt : op->seq) {
      if (marker_.GetRole(stmt) == Role::kProducer) {
        has_producer = true;
        break;
      }
    }
696
697
698
699
    bool need_producer_sync =
        has_producer && marker_.GetRole(op) == Role::kBoth;
    if (!need_producer_sync)
      return FilterByRole(op);
700

701
702
    auto seq_transformed =
        op->seq.Map([&](Stmt stmt) { return VisitStmt(stmt); });
703
704

    auto map = ExtractSyncPattern(op->seq);
705

706
707
708
709
710
711
712
713
714
715
716
717
718
719
    /*
      std::cout << "Print ExtractSyncPattern" << std::endl;
      for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
        std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " "
        << map.release_after[i] << std::endl;
      }
      std::cout << "Print sync pattern" << std::endl;
      for (auto pattern : map.patterns) {
        std::cout << pattern.release_idx << " " << pattern.acquire_idx <<
        std::endl;
      }
      std::cout << "End of ExtractSyncPattern" << std::endl;
      pipeline_info_.PrintPipelineInfo();
    */
720
721
722
723
    Array<Stmt> new_body;
    Map<String, ObjectRef> annotations;
    annotations.Set(String("stmt_group"), Integer(1));

724
    if (is_emitting_producer_) { // producer case
725
726
727
      ProducerTraitsCollector collector;
      for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
        Array<Stmt> block_stmt = {};
728
729
730
731
732
733
734
735
736
737
738
        if (!mbarrier_only_) {
          if (marker_.GetRole(op->seq[i]) == Role::kConsumer)
            continue;
          if (marker_.GetRole(op->seq[i]) == Role::kBoth) {
            block_stmt.push_back(seq_transformed[i]);
            new_body.push_back(MakeGroupBlock(
                block_stmt.size() == 1 ? block_stmt[0]
                                       : SeqStmt(std::move(block_stmt)),
                annotations));
            continue;
          }
739
        }
740

741
        for (int pattern_idx : map.acquire[i]) {
742
          PrimExpr acquire_barrier_id =
743
744
              stage_ + num_barriers_ + num_stages_ * pattern_idx;
          PrimExpr parity = map.is_loop_dependency(pattern_idx)
745
746
                                ? bitwise_xor(parity_, 1)
                                : parity_;
747
748
          block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
        }
749
750
751
752
753
754
755
756
        ICHECK(map.release[i].size() > 0);
        for (size_t j = 0; j < map.release[i].size(); j++) {
          int pattern_idx = map.release[i][j];
          PrimExpr release_barrier_id =
              stage_ + num_barriers_ + num_stages_ * pattern_idx;
          auto stmt =
              MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id);
          collector.Collect(stmt);
757
          block_stmt.push_back(stmt);
758
          if (collector.HasSimtCopy()) {
759
            block_stmt.push_back(makeCpAsyncBarrier(release_barrier_id));
760
            has_simt_copy_ = true;
761
          }
762
763
764
765
766
767
768
769
770
771
772
773
          if (map.release_after[i][j]) {
            block_stmt.push_back(makeArriveBarrier(release_barrier_id));
            for (int s = 0; s < num_stages_; s++) {
              released_barrier_.insert(s + num_barriers_ +
                                       num_stages_ * pattern_idx);
            }
          }
          collector.Clear();
          new_body.push_back(MakeGroupBlock(
              block_stmt.size() == 1 ? block_stmt[0]
                                     : SeqStmt(std::move(block_stmt)),
              annotations));
774
775
        }
      }
776
    } else { // consumer case
777
778
      for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
        Array<Stmt> block_stmt = {};
779
780
        if (marker_.GetRole(op->seq[i]) == Role::kProducer)
          continue;
781
        for (int pattern_idx : map.acquire[i]) {
782
          PrimExpr acquire_barrier_id =
783
784
              stage_ + num_barriers_ + num_stages_ * pattern_idx;
          PrimExpr parity = map.is_loop_dependency(pattern_idx)
785
786
                                ? bitwise_xor(parity_, 1)
                                : parity_;
787
788
789
          block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
        }
        block_stmt.push_back(seq_transformed[i]);
790
791
792
793
794
        for (size_t j = 0; j < map.release[i].size(); j++) {
          if (map.release_after[i][j]) {
            int pattern_idx = map.release[i][j];
            PrimExpr release_barrier_id =
                stage_ + num_barriers_ + num_stages_ * pattern_idx;
795
796
797
798
799
            if (only_has_wgmma_)
              block_stmt.push_back(makeArriveBarrier(
                  release_barrier_id, 0, EQ(FloorMod(thread_var_, 128), 0)));
            else
              block_stmt.push_back(makeArriveBarrier(release_barrier_id));
800
801
802
803
            for (int s = 0; s < num_stages_; s++) {
              released_barrier_.insert(s + num_barriers_ +
                                       num_stages_ * pattern_idx);
            }
804
805
          }
        }
806
807
808
809
        new_body.push_back(MakeGroupBlock(block_stmt.size() == 1
                                              ? block_stmt[0]
                                              : SeqStmt(std::move(block_stmt)),
                                          annotations));
810
811
812
813
      }
      // Filter out the producer stmts
      int cur_id = 0;
      PipelineInfo new_pipeline_info;
814
815
      for (int i = 0; i < static_cast<int>(pipeline_info_.op_infos.size());
           i++) {
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
        auto op_info = pipeline_info_.op_infos[i];
        bool is_producer = false;
        for (int j = 0; j < op_info.group_size; j++) {
          if (marker_.GetRole(op->seq[cur_id]) == Role::kProducer) {
            is_producer = true;
          }
          cur_id++;
        }
        if (is_producer) {
          ICHECK(op_info.group_size == 1);
        } else {
          new_pipeline_info.op_infos.push_back(op_info);
        }
      }
      pipeline_info_ = new_pipeline_info;
    }

    num_barriers_ += map.patterns.size() * num_stages_;

    ICHECK(new_body.size() > 0);
    return new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body));
  }

839
  /*!
   * \brief Emit a (possibly software-pipelined) loop for this emitter's role.
   *
   * Reads the "num_stages" and "tl_pipeline_{group,order,stage}" annotations,
   * pushes the loop onto loop_stack_, and derives the per-iteration pipeline
   * state used while emitting the body:
   *   linear_index = row-major linearization of all loops on loop_stack_
   *   stage_       = linear_index % num_stages
   *   parity_      = (parity_outer * extent + linear_index / num_stages) % 2
   * The body is then filtered by role through FilterByRole. When this is the
   * consumer emitter (!is_emitting_producer_) and the loop carries a non-empty
   * "tl_pipeline_group" annotation, the filtered loop is further rewritten by
   * GroupOpRewriter and that result is returned.
   *
   * parity_, stage_, num_stages_, pipeline_info_ and loop_stack_ are restored
   * to their enclosing values before returning. On the non-grouped path the
   * "num_stages" annotation is removed from the emitted loop, as are
   * "tl_pipeline_order"/"tl_pipeline_stage" when emitting the producer or when
   * no group info is present.
   */
  Stmt VisitStmt_(const ForNode *op) final {
    int num_stages = 1;
    auto num_stages_anno = op->annotations.Get("num_stages");
    if (num_stages_anno) {
      ICHECK(num_stages_anno->as<IntImmNode>())
          << "The num_stages annotation must be an integer constant.";
      num_stages = static_cast<int>(num_stages_anno->as<IntImmNode>()->value);
      ICHECK(num_stages_ == 1) << "Nested pipeline not supported.";
    }
    loop_stack_.emplace_back(LoopInfo{op->loop_var, op->extent, op->min});

    // Optional pipeline schedule attached to the loop as annotations.
    Array<Array<Integer>> group_info_array;
    Array<Integer> order_info_array;
    Array<Integer> stage_info_array;

    auto group_anno = op->annotations.Get("tl_pipeline_group");
    if (group_anno) {
      group_info_array = Downcast<Array<Array<Integer>>>(group_anno.value());
    }
    auto order_anno = op->annotations.Get("tl_pipeline_order");
    if (order_anno) {
      order_info_array = Downcast<Array<Integer>>(order_anno.value());
    }
    auto stage_anno = op->annotations.Get("tl_pipeline_stage");
    if (stage_anno) {
      stage_info_array = Downcast<Array<Integer>>(stage_anno.value());
    }

    PipelineInfo pipeline_info(group_info_array, order_info_array,
                               stage_info_array);
    if (pipeline_info.op_infos.size() > 0) {
      ICHECK(pipeline_info_.op_infos.size() == 0)
          << "Nested pipeline not supported.";
    }

    // Op grouping is applied only by the consumer emitter, and only when the
    // loop carries a non-empty group annotation.
    const bool apply_grouping =
        !is_emitting_producer_ && group_anno && group_info_array.size() > 0;

    // Save the enclosing pipeline context; it is restored after the body.
    PrimExpr parity_before = std::move(parity_);
    PrimExpr stage_before = std::move(stage_);
    int num_stages_before = num_stages_;
    PipelineInfo pipeline_info_before = pipeline_info_;

    num_stages_ = num_stages;
    pipeline_info_ = pipeline_info;

    // Row-major linearization of every enclosing loop, outermost first.
    PrimExpr linear_index = loop_stack_[0].loop_var - loop_stack_[0].min;
    for (size_t i = 1; i < loop_stack_.size(); ++i) {
      linear_index = linear_index * loop_stack_[i].extent +
                     (loop_stack_[i].loop_var - loop_stack_[i].min);
    }
    stage_ = FloorMod(linear_index, num_stages);
    parity_ = FloorMod(
        parity_before * op->extent + FloorDiv(linear_index, num_stages), 2);

    auto result = FilterByRole(op);

    // Run the group rewriter while pipeline_info_ still holds the state left
    // by visiting this loop's body (it may have been filtered there).
    Stmt grouped_for_node;
    if (apply_grouping && result.as<ForNode>()) {
      GroupOpRewriter group_op_rewriter(pipeline_info_);
      auto for_node = Downcast<For>(result);
      grouped_for_node = group_op_rewriter(for_node);
    }

    // Restore the enclosing pipeline context and leave this loop.
    parity_ = std::move(parity_before);
    stage_ = std::move(stage_before);
    num_stages_ = num_stages_before;
    pipeline_info_ = pipeline_info_before;
    loop_stack_.pop_back();

    if (!result.as<ForNode>()) {
      return result;
    }
    if (apply_grouping) {
      return grouped_for_node;
    }

    // Strip pipeline annotations from the emitted loop.
    auto for_node = Downcast<For>(result);
    for_node.CopyOnWrite()->annotations.erase("num_stages");
    if (is_emitting_producer_ || group_info_array.size() == 0) {
      for_node.CopyOnWrite()->annotations.erase("tl_pipeline_order");
      for_node.CopyOnWrite()->annotations.erase("tl_pipeline_stage");
    }
    return for_node;
  }

924
925
926
927
928
929
930
  // Statement kinds with no dedicated handling in this emitter: each one is
  // forwarded as-is to FilterByRole, which produces the emitted statement.
  Stmt VisitStmt_(const IfThenElseNode *op) final {
    return FilterByRole(op);
  }
  Stmt VisitStmt_(const EvaluateNode *op) final {
    return FilterByRole(op);
  }
  Stmt VisitStmt_(const AttrStmtNode *op) final {
    return FilterByRole(op);
  }
  Stmt VisitStmt_(const BufferStoreNode *op) final {
    return FilterByRole(op);
  }
  Stmt VisitStmt_(const LetStmtNode *op) final {
    return FilterByRole(op);
  }
  Stmt VisitStmt_(const AssertStmtNode *op) final {
    return FilterByRole(op);
  }
  /*!
   * \brief Block nodes are not supported by this emitter; visiting one aborts
   *        with a descriptive check failure instead of a bare "Check failed: 0".
   */
  Stmt VisitStmt_(const BlockNode *op) final {
    ICHECK(false) << "WSCodeEmitter does not support visiting a BlockNode.";
    return Stmt();
  }
934
  /*!
   * \brief BlockRealize nodes are not supported by this emitter; visiting one
   *        aborts with a descriptive check failure.
   */
  Stmt VisitStmt_(const BlockRealizeNode *op) final {
    ICHECK(false)
        << "WSCodeEmitter does not support visiting a BlockRealizeNode.";
    return Stmt();
  }

  /// One release/acquire pair between two statements of a sequence.
  /// Both fields are indices into the statement sequence; a pattern whose
  /// release_idx is greater than its acquire_idx crosses loop iterations.
  struct SyncPattern {
    int release_idx; ///< Statement after/at which the barrier is released.
    int acquire_idx; ///< Statement before which the barrier is awaited.
  };

  struct SyncPatternMap {
944
945
946
    std::vector<std::vector<int>> acquire;
    std::vector<std::vector<int>> release;
    std::vector<std::vector<bool>> release_after;
947
    std::vector<SyncPattern> patterns;
948
949
950
951
952
953
954
955
956
957

    void resize(size_t n) {
      acquire.resize(n);
      release.resize(n);
      release_after.resize(n);
    }

    bool is_loop_dependency(int pattern_idx) {
      return patterns[pattern_idx].release_idx >
             patterns[pattern_idx].acquire_idx;
958
959
960
    }
  };

961
962
963
  /*!
   * \brief Build the candidate release/acquire pairs for a statement sequence.
   *
   * Two statements conflict when they have different roles and one writes a
   * buffer the other reads. A buffer is identified by the entry registered in
   * buffer_data_to_buffer_ for its data var, or by the region's own buffer.
   *
   * - Forward pairs: for every statement r, the first later statement a that
   *   conflicts with r yields {release = r, acquire = a}.
   * - Loop-carried pairs, only when parity_ is not the constant 0: for every
   *   statement r, the first earlier statement a (scanning from index 0) that
   *   conflicts with r yields {release = r, acquire = a}.
   *
   * \param seq_stmt The statements to analyze.
   * \param is_producer Role flag per statement (same length as seq_stmt).
   * \return Forward pairs first, then loop-carried pairs.
   */
  std::vector<SyncPattern>
  CreateBaseSyncPairs(Array<Stmt> seq_stmt,
                      const std::vector<bool> &is_producer) {
    using BufferSet = std::set<const BufferNode *>;
    const int n = seq_stmt.size();

    // Resolve an accessed region to its canonical buffer node.
    auto canonical_buffer = [this](const auto &region) {
      const auto &data = region->buffer->data;
      return buffer_data_to_buffer_.count(data)
                 ? buffer_data_to_buffer_[data].get()
                 : region->buffer.get();
    };

    std::vector<BufferSet> reads(n);
    std::vector<BufferSet> writes(n);
    for (int idx = 0; idx < n; ++idx) {
      // Wrap the statement in an anonymous block so the access-region
      // analysis can be reused on it.
      Block wrapper(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
                    /*name_hint=*/"",
                    /*body*/ seq_stmt[idx]);
      auto access = GetBlockAccessRegion(wrapper, buffer_data_to_buffer_);
      for (const auto &region : access[0]) {
        reads[idx].insert(canonical_buffer(region));
      }
      for (const auto &region : access[1]) {
        writes[idx].insert(canonical_buffer(region));
      }
    }

    auto overlaps = [](const BufferSet &lhs, const BufferSet &rhs) {
      for (const BufferNode *buf : lhs) {
        if (rhs.count(buf)) {
          return true;
        }
      }
      return false;
    };
    auto conflicts = [&](int first, int second) {
      return is_producer[first] != is_producer[second] &&
             (overlaps(writes[first], reads[second]) ||
              overlaps(reads[first], writes[second]));
    };

    std::vector<SyncPattern> sync_patterns;

    // Forward dependencies inside one iteration.
    for (int rel = 0; rel < n; ++rel) {
      for (int acq = rel + 1; acq < n; ++acq) {
        if (conflicts(rel, acq)) {
          sync_patterns.push_back({rel, acq});
          break;
        }
      }
    }

    // Loop-carried dependencies: only meaningful when parity_ varies.
    if (!is_zero(parity_)) {
      for (int rel = 0; rel < n; ++rel) {
        for (int acq = 0; acq < rel; ++acq) {
          if (conflicts(rel, acq)) {
            sync_patterns.push_back({rel, acq});
            break;
          }
        }
      }
    }

    return sync_patterns;
  }

1036
1037
1038
  static std::vector<SyncPattern>
  RemoveUnusedSyncPatterns(const std::vector<SyncPattern> &sync_patterns,
                           const std::vector<bool> &is_producer) {
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
    /*
      Simplify multiple release-acquire pairs into one
      ------------------
        Produce(A)
        Produce(B)
        Consume(A, B)
      ------------------
      [(0, 2), (1, 2), (2, 0)] -> [(1, 2), (2, 0)]

      Or
      ------------------
        Produce(A, B)
        Consume(A)
        Consume(B)
      ------------------
      [(0, 1), (1, 0), (2, 0)] -> [(0, 1), (2, 0)]
    */
    int M = sync_patterns.size();
    std::vector<bool> removed(M, false);
    for (int i = 0; i < M; i++) {
      for (int j = 0; j < M; j++) {
        if (is_producer[sync_patterns[i].acquire_idx] ==
                is_producer[sync_patterns[j].acquire_idx] &&
            sync_patterns[i].acquire_idx >= sync_patterns[j].acquire_idx &&
            sync_patterns[i].release_idx < sync_patterns[j].release_idx)
          removed[i] = true;
      }
    }

    std::vector<SyncPattern> sync_pattern_cleaned;
    sync_pattern_cleaned.reserve(M);
    for (int i = 0; i < M; i++)
1071
1072
      if (!removed[i])
        sync_pattern_cleaned.push_back(sync_patterns[i]);
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085

    return sync_pattern_cleaned;
  }

  /*!
   * \brief Compute the per-statement barrier layout for a statement sequence.
   *
   * Roles come from marker_ (kProducer versus everything else). Candidate
   * pairs from CreateBaseSyncPairs are pruned by RemoveUnusedSyncPatterns.
   * Every surviving pattern is recorded at its acquire statement, and at its
   * release statement with release_after = true.
   *
   * A backward sweep then propagates releases within each role. A statement
   * with no release of its own inherits every pattern released by later
   * statements of the same role (release_after = false). A statement that
   * has releases adds them to the set its earlier same-role statements
   * inherit.
   */
  SyncPatternMap ExtractSyncPattern(Array<Stmt> seq_stmt) {
    const size_t num_stmts = seq_stmt.size();

    std::vector<bool> is_producer;
    is_producer.reserve(num_stmts);
    for (const auto &stmt : seq_stmt) {
      is_producer.push_back(marker_.GetRole(stmt) == Role::kProducer);
    }

    std::vector<SyncPattern> sync_patterns = RemoveUnusedSyncPatterns(
        CreateBaseSyncPairs(seq_stmt, is_producer), is_producer);

    SyncPatternMap map;
    map.resize(num_stmts);
    map.patterns = sync_patterns;

    for (size_t idx = 0; idx < sync_patterns.size(); ++idx) {
      const SyncPattern &pattern = sync_patterns[idx];
      map.acquire[pattern.acquire_idx].push_back(idx);
      map.release[pattern.release_idx].push_back(idx);
      map.release_after[pattern.release_idx].push_back(true);
    }

    // Patterns released by later statements, tracked separately per role.
    std::vector<int> pending_producer;
    std::vector<int> pending_consumer;
    for (int idx = static_cast<int>(num_stmts) - 1; idx >= 0; --idx) {
      std::vector<int> &pending =
          is_producer[idx] ? pending_producer : pending_consumer;
      std::vector<int> &releases = map.release[idx];
      if (releases.empty()) {
        for (int pattern_idx : pending) {
          releases.push_back(pattern_idx);
          map.release_after[idx].push_back(false);
        }
      } else {
        pending.insert(pending.end(), releases.begin(), releases.end());
      }
    }
    return map;
  }

  const bool is_emitting_producer_;
  Map<Var, Buffer> buffer_data_to_buffer_;
  std::unordered_set<int> released_barrier_;
1139
  const WarpSpecializedRoleMarker &marker_;
1140
1141
1142
1143
1144

  int num_barriers_ = 0;
  PrimExpr parity_ = 0;
  PrimExpr stage_ = 0;
  int num_stages_ = 1;
1145
  std::vector<LoopInfo> loop_stack_;
1146
  Var thread_var_;
1147
  bool mbarrier_only_ = false;
1148
1149
  PipelineInfo pipeline_info_;
  friend class WarpSpecializedRewriter;
1150
1151
  bool only_has_wgmma_ = false;
  bool has_simt_copy_ = false;
1152
1153
1154
};

class WarpSpecializedRewriter : public StmtExprMutator {
1155
public:
1156
1157
1158
1159
1160
1161
  WarpSpecializedRewriter(bool disable_warp_specialized,
                          bool disable_shuffle_elect)
      : disable_warp_specialized_(disable_warp_specialized),
        disable_shuffle_elect_(disable_shuffle_elect) {}
  static PrimFunc Substitute(PrimFunc f, bool disable_warp_specialized,
                             bool disable_shuffle_elect) {
1162
1163
1164
    // Check if function only uses threadIdx.x before proceeding
    if (!ThreadTagChecker::HasOnlyThreadIdxX(f)) {
      LOG(WARNING) << "WarpSpecialize will be disabled because the program "
1165
                      "uses thread tags other than threadIdx.x."
1166
1167
1168
1169
1170
1171
                   << "If you want to use warp specialization, please refactor "
                      "your program to use threadIdx.x only";
      // Return original function unchanged if other thread tags are found
      return f;
    }

1172
1173
    auto T = WarpSpecializedRewriter(disable_warp_specialized,
                                     disable_shuffle_elect);
1174
    T.buffer_lca_ = DetectBufferAccessLCA(f);
1175
1176
    for (auto [buffer, _] : T.buffer_lca_)
      T.buffer_data_to_buffer_.Set(buffer->data, buffer);
1177
1178
1179
1180
    f.CopyOnWrite()->body = T(f->body);
    return f;
  }

1181
1182
private:
  Stmt VisitStmt_(const AttrStmtNode *op) final {
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
    if (op->attr_key == tir::attr::thread_extent &&
        Downcast<IterVar>(op->node)->thread_tag == "threadIdx.x") {
      thread_iv_ = Downcast<IterVar>(op->node);
      need_update_thread_extent_ = false;
      AttrStmt attr_stmt = Downcast<AttrStmt>(StmtExprMutator::VisitStmt_(op));
      if (need_update_thread_extent_) {
        thread_iv_.CopyOnWrite()->dom = {0, updated_thread_extent_.value()};
        attr_stmt.CopyOnWrite()->node = thread_iv_;
        attr_stmt.CopyOnWrite()->value = updated_thread_extent_.value();
      }
      thread_iv_ = {};
      return attr_stmt;
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

1200
1201
1202
1203
  // If users define a thread binding, we will replace the thread binding with
  // threadIdx.x We require the thread binding is threadIdx.x, and the extent is
  // the same as the thread extent
  Stmt VisitStmt_(const ForNode *op) final {
1204
1205
1206
1207
1208
1209
1210
    ICHECK(thread_iv_.defined());
    For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
    if (for_node->kind == ForKind::kThreadBinding) {
      ICHECK(for_node->thread_binding.defined());
      String thread_tag = for_node->thread_binding.value()->thread_tag;
      ICHECK(thread_tag == "threadIdx.x") << "Only support threadIdx.x";
      Var thread_iv = Downcast<Var>(for_node->loop_var);
1211
      Stmt new_body =
1212
          ThreadIdxRewriter::Rewrite(for_node->body, thread_iv, thread_iv_, 0);
1213
1214
1215
1216
1217
      return new_body;
    }
    return for_node;
  }

1218
  /**
1219
1220
   * @brief Rewrite a BlockRealize for warp specialization, inserting barriers
   * and emitting producer/consumer bodies.
1221
1222
1223
1224
1225
   *
   * This visitor handles BlockRealize nodes when a thread IterVar (thread_iv_)
   * is defined and warp-specialization is applicable. It:
   * - Determines producer/consumer roles via WarpSpecializedRoleMarker and
   *   returns the original block if no producer is detected.
1226
1227
   * - If warp specialization is disabled, emits only mbarrier initialization
   * and the mbarrier-only transformed body.
1228
1229
1230
   * - Otherwise, detects WgMMA usage for the block body and constructs separate
   *   WSCodeEmitter instances for producer and consumer paths (propagating the
   *   WgMMA flag to the consumer emitter).
1231
1232
1233
   * - Generates producer/consumer code, applies register hint calls
   * (set_max_nreg) when available, and rewrites thread indices with
   * ThreadIdxRewriter to partition threads between producer and consumer roles.
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
   * - Computes and initializes a list of mbarrier handles with per-barrier
   *   arrive thread counts (taking SIMT-copy and WgMMA cases into account).
   * - Wraps the transformed body in an IfThenElse that dispatches producer vs
   *   consumer based on thread index, and annotates the region with the
   *   "kWarpSpecializationScope" attribute that contains producer/consumer
   *   thread extents.
   *
   * Side effects:
   * - May update member state: only_has_wgmma_, updated_thread_extent_,
   *   need_update_thread_extent_.
   * - May abort via ICHECK if invariants (e.g., matching barrier counts) are
   *   violated.
   *
   * @return The possibly rewritten BlockRealize statement (original when no
   *         warp-specialization is applied or thread_iv_ is undefined).
   */
1250
1251
1252
  Stmt VisitStmt_(const BlockRealizeNode *op) final {
    BlockRealize block_realize =
        Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
1253
1254
1255
1256
1257
1258
    if (!thread_iv_.defined()) {
      return block_realize;
    }

    Block block = block_realize->block;
    WarpSpecializedRoleMarker marker(buffer_data_to_buffer_);
1259
    marker.Prepare(block);
1260
1261
1262
1263
1264
1265
    marker(block);
    if (!marker.HasProducer()) {
      // Cannot detect any producer here, directly return.
      return block_realize;
    }

1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
    if (disable_warp_specialized_) {
      WSCodeEmitter mbarrier_emitter(true, thread_iv_, buffer_data_to_buffer_,
                                     marker, true);
      auto code = mbarrier_emitter(block->body);
      int num_barriers = mbarrier_emitter.num_barriers_;
      Array<PrimExpr> barrier_num_threads;
      barrier_num_threads.reserve(num_barriers);
      PrimExpr arrive_thread_count = thread_iv_->dom->extent;
      for (int i = 0; i < num_barriers; i++) {
        barrier_num_threads.push_back(arrive_thread_count);
      }
      Stmt init_barrier = Evaluate(Call(
1278
          DataType::Handle(), create_list_of_mbarrier(), barrier_num_threads));
1279
1280
1281
1282
      block.CopyOnWrite()->body = SeqStmt({init_barrier, code});
      block_realize.CopyOnWrite()->block = block;
      return block_realize;
    }
1283
    only_has_wgmma_ = WgMMACollector::HasWgMMA(block->body);
1284
    WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker);
1285
1286
    WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker,
                           false, only_has_wgmma_);
1287
1288
1289
1290
1291
    Stmt producer_code = producer(block->body);
    Stmt consumer_code = consumer(block->body);
    PrimExpr consumer_thread_extent = thread_iv_->dom->extent;
    PrimExpr producer_thread_extent = thread_iv_->dom->extent;
    // Need one warp-group for bulk-copy only case
1292
1293
    if (!marker.HasSimtCopy())
      producer_thread_extent = 128;
1294
1295

    updated_thread_extent_ = consumer_thread_extent + producer_thread_extent;
1296
1297
1298
1299
1300
1301
1302
1303

    producer_code = ThreadIdxRewriter::Rewrite(
        producer_code, thread_iv_->var,
        thread_iv_->var - consumer_thread_extent, producer_thread_extent,
        !disable_shuffle_elect_);
    consumer_code = ThreadIdxRewriter::Rewrite(
        consumer_code, thread_iv_->var, thread_iv_->var, consumer_thread_extent,
        !disable_shuffle_elect_);
1304
1305
1306
1307
1308
1309
1310
1311
    need_update_thread_extent_ = true;

    ICHECK(producer.num_barriers_ == consumer.num_barriers_)
        << producer.num_barriers_ << " " << consumer.num_barriers_;
    int num_barriers = consumer.num_barriers_;
    Array<PrimExpr> barrier_num_threads;
    barrier_num_threads.reserve(num_barriers);
    for (int i = 0; i < num_barriers; i++) {
1312
1313
1314
      PrimExpr arrive_thread_count =
          producer.released_barrier_.count(i)
              ? (producer.hasSimtCopy() ? producer_thread_extent : 1)
1315
1316
              : (only_has_wgmma_ ? FloorDiv(consumer_thread_extent, 128)
                                 : consumer_thread_extent);
1317
1318
1319
      barrier_num_threads.push_back(arrive_thread_count);
    }

1320
    Stmt init_barrier = Evaluate(Call(
1321
        DataType::Handle(), create_list_of_mbarrier(), barrier_num_threads));
1322
1323
    Stmt body = IfThenElse(GE(thread_iv_->var, consumer_thread_extent),
                           producer_code, consumer_code);
1324
    // Add an attr here to handle the partial thread count in ThreadSync pass.
1325
1326
    Array<IntImm> ws_partition = {Downcast<IntImm>(producer_thread_extent),
                                  Downcast<IntImm>(consumer_thread_extent)};
1327
    body = AttrStmt(ws_partition, attr::kWarpSpecializationScope, 0, body);
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341

    block.CopyOnWrite()->body = SeqStmt({init_barrier, body});
    block_realize.CopyOnWrite()->block = block;
    return block_realize;
  }

  WarpSpecializedRewriter() = default;

  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Optional<Stmt>> buffer_lca_;
  Map<Buffer, Buffer> buffer_remap_;
  IterVar thread_iv_;
  Optional<PrimExpr> updated_thread_extent_;
  bool need_update_thread_extent_ = false;
1342
  bool disable_warp_specialized_ = false;
1343
  bool disable_shuffle_elect_ = false;
1344
  bool only_has_wgmma_ = false;
1345
1346
};

1347
1348
class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
public:
1349
  // return true means this aws will be disabled
1350
1351
1352
  static bool Detect(Stmt stmt, bool skip_thread_partition = false) {
    WarpSpecializedDetector detector;
    detector.VisitStmt(stmt);
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
    if (detector.has_warp_specialization_) {
      LOG(WARNING) << "Auto warp specialization will be disabled because warp "
                      "specialization is manually enabled";
      return true;
    }
    if (detector.has_tma_op_ && detector.has_mbarrier_op_) {
      LOG(WARNING) << "Auto warp specialization will be disabled because TMA "
                      "and mbarrier are both present";
      return true;
    }
    return false;
1364
1365
1366
1367
1368
  }

  WarpSpecializedDetector() {
    has_tma_op_ = false;
    has_mbarrier_op_ = false;
1369
    has_warp_specialization_ = false;
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
  }

private:
  void VisitStmt_(const EvaluateNode *op) final {
    if (const CallNode *call = op->value.as<CallNode>()) {
      if (call->op.same_as(create_list_of_mbarrier()) ||
          call->op.same_as(mbarrier_wait_parity()) ||
          call->op.same_as(builtin::ptx_arrive_barrier()) ||
          call->op.same_as(builtin::ptx_cp_async_barrier())) {
        has_mbarrier_op_ = true;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  void VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
        op->op.same_as(set_max_nreg())) {
      has_tma_op_ = true;
    }
    IRVisitorWithAnalyzer::VisitExpr_(op);
  }

1393
  void VisitStmt_(const AttrStmtNode *op) final {
1394
1395
1396
1397
    if (op->attr_key == "warp_specialize" &&
        op->value.as<IntImmNode>()->value == 1) {
      has_warp_specialization_ = true;
    }
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        ICHECK(iv->dom->extent.as<IntImmNode>());
        thread_var_ = iv;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

1408
  bool has_tma_op_{false};
1409
  IterVar thread_var_;
1410
  bool has_mbarrier_op_{false};
1411
  bool has_warp_specialization_{false};
1412
1413
};

1414
1415
1416
1417
using namespace tir::transform;

/*!
 * \brief Create the "tl.WarpSpecialized" PrimFunc pass.
 *
 * Reads kDisableWarpSpecialized and kDisableShuffleElect from the pass
 * context (both default to false). Functions for which
 * WarpSpecializedDetector::Detect reports that auto warp specialization must
 * be skipped are returned unchanged. All others are rewritten by
 * WarpSpecializedRewriter::Substitute.
 */
tvm::transform::Pass WarpSpecialized() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) -> PrimFunc {
    const bool disable_warp_specialized =
        ctx->GetConfig<Bool>(kDisableWarpSpecialized, Bool(false)).value();
    const bool disable_shuffle_elect =
        ctx->GetConfig<Bool>(kDisableShuffleElect, Bool(false)).value();

    // Detect() answers "should auto warp specialization be skipped?".
    const bool skip_auto_ws = WarpSpecializedDetector::Detect(f->body);
    if (skip_auto_ws) {
      return f;
    }
    return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized,
                                               disable_shuffle_elect);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
}

1433
1434
1435
1436
// Register the pass factory as the FFI global "tl.transform.WarpSpecialized".
TVM_FFI_STATIC_INIT_BLOCK({
  tvm::ffi::reflection::GlobalDef().def("tl.transform.WarpSpecialized",
                                        WarpSpecialized);
});
1437

1438
1439
} // namespace tl
} // namespace tvm