/*!
 * \file warp_specialized_rewriter.cc
 * \brief Warp specialized Pipeline for cuda GPU (sm90+)
 */

#include <iostream>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../op/builtin.h"
#include "./common/collector.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "tir/analysis/var_use_def_analysis.h"

namespace tvm {
namespace tl {

using namespace tir;
22
using arith::IRVisitorWithAnalyzer;
23
24
25

// Role assigned to each statement by WarpSpecializedRoleMarker: needed only by
// the consumer side, only by the producer side, or by both sides.
enum class Role { kConsumer, kProducer, kBoth };

26
class ProducerBufferDetector : public StmtExprVisitor {
27
public:
28
29
30
31
32
  ProducerBufferDetector(
      std::unordered_set<const BufferNode *> cur_producer_buffers)
      : cur_producer_buffers_(cur_producer_buffers) {}

  void clear() { has_producer_buffer_ = false; }
33
34

  void VisitExpr_(const CallNode *call) final {
35
    if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
36
      has_producer_buffer_ = true;
37
    }
38
    StmtExprVisitor::VisitExpr_(call);
39
40
  }

41
42
43
44
45
46
47
48
49
  void VisitExpr_(const BufferLoadNode *op) final {
    if (cur_producer_buffers_.count(op->buffer.get())) {
      has_producer_buffer_ = true;
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  bool has_producer_buffer_ = false;
  std::unordered_set<const BufferNode *> cur_producer_buffers_;
50
51
52
53
54
};

class ProducerUsedBufferFinder : public StmtExprVisitor {
public:
  auto FindProducerusedBuffer(Stmt stmt) {
55
56
57
58
59
60
61
62
63
64
    producer_buffers_.clear();
    std::unordered_set<const BufferNode *> last_producer_buffers_;
    for (;;) {
      VisitStmt(stmt);
      if (producer_buffers_ == last_producer_buffers_) {
        break;
      }
      last_producer_buffers_ = producer_buffers_;
    }
    return producer_buffers_;
65
66
67
68
69
70
71
  }

  void InsertBuffer(const PrimExpr &expr) {
    // Find the buffer that is used in the condition
    VarUseDefAnalyzer usage(Array<Var>{});
    usage(expr);
    for (const auto &buffer : usage.buffer_use_count_) {
72
      producer_buffers_.insert(buffer.first);
73
74
75
76
    }
  }

  void VisitStmt_(const IfThenElseNode *op) final {
77
78
    ProducerBufferDetector producer_buffer_detector(producer_buffers_);
    producer_buffer_detector(op->then_case);
79
    if (op->else_case.defined()) {
80
      producer_buffer_detector(op->else_case.value());
81
    }
82
    if (producer_buffer_detector.has_producer_buffer_) {
83
84
85
86
87
88
      InsertBuffer(op->condition);
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const ForNode *op) final {
89
90
91
    ProducerBufferDetector producer_buffer_detector(producer_buffers_);
    producer_buffer_detector(op->body);
    if (producer_buffer_detector.has_producer_buffer_) {
92
93
94
95
96
97
      InsertBuffer(op->min);
      InsertBuffer(op->extent);
    }
    StmtExprVisitor::VisitStmt_(op);
  }

98
99
100
101
102
103
104
  void VisitStmt_(const BufferStoreNode *op) final {
    if (producer_buffers_.count(op->buffer.get())) {
      InsertBuffer(op->value);
    }
    StmtExprVisitor::VisitStmt_(op);
  }

105
106
107
108
  void VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
      for (auto arg : op->args) {
        if (auto buffer_load = arg.as<BufferLoadNode>()) {
109
          producer_buffers_.insert(buffer_load->buffer.get());
110
111
112
113
114
        }
      }
    }
  }

115
private:
116
  std::unordered_set<const BufferNode *> producer_buffers_;
117
118
};

119
class WarpSpecializedRoleMarker : public StmtVisitor {
120
public:
121
122
123
  WarpSpecializedRoleMarker(Map<Var, Buffer> buffer_data_to_buffer)
      : buffer_data_to_buffer_(buffer_data_to_buffer) {}

124
125
  void Prepare(const Stmt &stmt) {
    ProducerUsedBufferFinder finder;
126
    producer_buffers_ = finder.FindProducerusedBuffer(stmt);
127
128
  }

129
  Role GetRole(const StmtNode *stmt) const {
130
131
132
133
134
    auto it = map_.find(stmt);
    ICHECK(it != map_.end());
    return it->second;
  }

135
  Role GetRole(const Stmt &stmt) const { return GetRole(stmt.get()); }
136

137
  void VisitStmt_(const EvaluateNode *op) final {
138
139
    Role role = Role::kConsumer;
    if (auto call = op->value.as<CallNode>()) {
140
      if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
141
142
143
        role = Role::kProducer;
        has_bulk_copy_ = true;
      }
144
145
146
      if (call->op.same_as(loop_break())) {
        role = Role::kBoth;
      }
147
148
149
150
    }
    SetRole(op, role);
  }

151
152
153
  void VisitStmt_(const BufferStoreNode *op) final {
    bool is_shared_store =
        op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
154
    if (producer_buffers_.count(op->buffer.get())) {
155
156
157
      SetRole(op, Role::kBoth);
      return;
    }
158
159
160
161
162
163
164
165
166
167
168
    if (!is_shared_store) {
      SetRole(op, Role::kConsumer);
      return;
    }

    // Check reads from global
    Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
                /*body*/ GetRef<Stmt>(op));
    auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
    auto reads = access[0];
    Role role = Role::kProducer;
169
170
    if (reads.empty())
      role = Role::kConsumer;
171
172
173
174
175
176
    for (auto read : reads) {
      if (read->buffer.scope() != "global") {
        role = Role::kConsumer;
        break;
      }
    }
177
178
    if (role == Role::kProducer)
      has_simt_copy_ = true;
179
180
181
    SetRole(op, role);
  }

182
  void VisitStmt_(const SeqStmtNode *op) final {
183
184
185
186
187
188
189
190
191
192
193
    StmtVisitor::VisitStmt_(op);
    auto role = GetRole(op->seq[0]);
    for (auto stmt : op->seq) {
      if (role != GetRole(stmt)) {
        role = Role::kBoth;
        break;
      }
    }
    SetRole(op, role);
  }

194
  void VisitStmt_(const IfThenElseNode *op) final {
195
196
197
198
    StmtVisitor::VisitStmt_(op);
    auto role = GetRole(op->then_case);
    if (op->else_case.defined()) {
      auto role_else = GetRole(op->else_case.value());
199
200
      if (role != role_else)
        role = Role::kBoth;
201
202
203
204
    }
    SetRole(op, role);
  }

205
  void VisitStmt_(const BlockRealizeNode *op) final {
206
207
208
209
    StmtVisitor::VisitStmt_(op);
    SetRole(op, GetRole(op->block));
  }

210
211
212
213
214
215
  void VisitStmt_(const AllocateNode *op) final {
    StmtVisitor::VisitStmt_(op);
    Role role = Role::kConsumer;
    SetRole(op, role);
  }

216
  template <class NodeType> void HandleBodyStmt(const NodeType *op) {
217
218
219
220
    StmtVisitor::VisitStmt_(op);
    SetRole(op, GetRole(op->body));
  }

221
  void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); }
222
  void VisitStmt_(const WhileNode *op) final { HandleBodyStmt(op); }
223
224
225
226
  void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); }
227
228
229
230
231

  bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; }

  bool HasSimtCopy() { return has_simt_copy_; }

232
233
private:
  void SetRole(const StmtNode *stmt, Role role) { map_[stmt] = role; }
234
  Map<Var, Buffer> buffer_data_to_buffer_;
235
  std::unordered_map<const StmtNode *, Role> map_;
236
237
  bool has_simt_copy_ = false;
  bool has_bulk_copy_ = false;
238
  std::unordered_set<const BufferNode *> producer_buffers_;
239
240
241
};

/*! \brief Build a handle-typed `get_mbarrier(barrier_id)` call expression. */
static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
  return Call(DataType::Handle(), get_mbarrier(), {barrier_id});
}

/*!
 * \brief Build an Evaluate statement that calls ptx_arrive_barrier on the
 *        barrier obtained from `barrier_id`.
 */
static Stmt makeArriveBarrier(PrimExpr barrier_id) {
  auto call = Call(DataType::Handle(), builtin::ptx_arrive_barrier(),
                   {makeGetBarrier(barrier_id)});
  return Evaluate(call);
}

/*!
 * \brief Build an Evaluate statement that calls ptx_cp_async_barrier on the
 *        barrier obtained from `barrier_id`.
 */
static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) {
  auto call = Call(DataType::Handle(), builtin::ptx_cp_async_barrier(),
                   {makeGetBarrier(barrier_id)});
  return Evaluate(call);
}

/*!
 * \brief Build an Evaluate statement that calls mbarrier_wait_parity with the
 *        barrier obtained from `barrier_id` and the given `parity`.
 */
static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
  auto call = Call(DataType::Handle(), mbarrier_wait_parity(),
                   {makeGetBarrier(barrier_id), parity});
  return Evaluate(call);
}

class ProducerTraitsCollector : public StmtExprVisitor {
264
public:
265
266
  ProducerTraitsCollector() { Clear(); }

267
  void Clear() { has_simt_copy = false; }
268
269
270
271
272

  void Collect(Stmt stmt) { VisitStmt(stmt); }

  bool HasSimtCopy() { return has_simt_copy; }

273
private:
274
275
276
277
278
279
280
281
282
283
284
285
  void VisitStmt_(const IfThenElseNode *op) final {
    bool old_in_if_cond = in_if_cond_;
    in_if_cond_ = true;
    VisitExpr(op->condition);
    in_if_cond_ = old_in_if_cond;

    VisitStmt(op->then_case);
    if (op->else_case.defined()) {
      VisitStmt(op->else_case.value());
    }
  }

286
  void VisitExpr_(const BufferLoadNode *op) final {
287
288
289
    if (!in_if_cond_) {
      has_simt_copy = true;
    }
290
291
292
293
    StmtExprVisitor::VisitExpr_(op);
  }

  bool has_simt_copy;
294
  bool in_if_cond_ = false;
295
296
297
298
};

// Rewrite the producer Stmt to use the correct barrier index
class MbarrierRewriter : public StmtExprMutator {
299
public:
300
301
302
303
304
305
  static Stmt Rewrite(Stmt stmt, PrimExpr barrier_id) {
    MbarrierRewriter rewriter;
    rewriter.producer_barrier_idx_ = barrier_id;
    return rewriter(stmt);
  }

306
307
private:
  PrimExpr VisitExpr_(const CallNode *op) final {
308
    auto call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
309
    if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
310
311
312
313
314
315
316
317
318
319
      Call access_ptr = Downcast<Call>(call->args[2]);
      ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
      call.CopyOnWrite()->args.Set(1, makeGetBarrier(producer_barrier_idx_));
    }
    return call;
  }
  PrimExpr producer_barrier_idx_;
};

class ThreadIdxRewriter : public StmtExprMutator {
320
public:
321
322
323
324
325
  static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced) {
    auto rewriter = ThreadIdxRewriter(thread_var, replaced);
    return rewriter(stmt);
  }

326
private:
327
328
329
  ThreadIdxRewriter(Var thread_var, PrimExpr replaced)
      : thread_var_(thread_var), replaced_(replaced) {}

330
  PrimExpr VisitExpr_(const VarNode *var) final {
331
332
333
334
335
336
337
338
339
340
341
    if (var == thread_var_.get()) {
      return replaced_;
    } else {
      return StmtExprMutator::VisitExpr_(var);
    }
  }

  Var thread_var_;
  PrimExpr replaced_;
};

342
343
344
345
346
347
/*!
 * \brief Wrap `stmt` in an anonymous Block that carries `annotations`.
 *
 * The block has no iter vars, read/write regions, init, alloc buffers or
 * match buffers.
 */
Block MakeGroupBlock(const Stmt &stmt,
                     const Map<String, ObjectRef> &annotations) {
  Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
              /*body*/ stmt,
              /*init=*/{}, /*alloc_buffers=*/{}, /*match_buffers=*/{},
              /*annotations=*/annotations);
  return block;
}

/*!
 * \brief One pipeline op: a group of statement indices together with its
 *        order and stage values.
 */
struct OpInfo {
  // Default-initialized so a freshly declared OpInfo never holds
  // indeterminate values.
  int group_size = 0; // number of entries in `group`
  int order = 0;
  int stage = 0;
  std::vector<int> group; // statement indices belonging to this op
};
struct PipelineInfo {
  std::vector<OpInfo> op_infos;

  PipelineInfo() = default;
359
360
  PipelineInfo(Array<Array<Integer>> group_info, Array<Integer> order_info,
               Array<Integer> stage_info) {
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
    int n = static_cast<int>(group_info.size());
    ICHECK(n == static_cast<int>(order_info.size()));
    ICHECK(n == static_cast<int>(stage_info.size()));
    // int cur_id = 0;
    for (int i = 0; i < n; i++) {
      OpInfo op_info;
      op_info.group_size = group_info[i].size();
      for (int j = 0; j < op_info.group_size; j++) {
        op_info.group.push_back(group_info[i][j].as<IntImmNode>()->value);
      }
      op_info.order = order_info[i].as<IntImmNode>()->value;
      op_info.stage = stage_info[i].as<IntImmNode>()->value;
      op_infos.push_back(op_info);
    }
  }

377
  PipelineInfo(const PipelineInfo &other) {
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
    for (auto op_info : other.op_infos) {
      op_infos.push_back(op_info);
    }
  }

  std::pair<int, int> FindStmt(int stmt_idx) {
    for (size_t i = 0; i < op_infos.size(); i++) {
      for (size_t j = 0; j < op_infos[i].group.size(); j++) {
        if (op_infos[i].group[j] == stmt_idx) {
          return std::make_pair(i, j);
        }
      }
    }
    return std::make_pair(-1, -1);
  }

  void UpdateOrder(int order) {
    for (int i = 0; i < static_cast<int>(op_infos.size()); i++) {
      if (op_infos[i].order >= order && op_infos[i].order > 0) {
        op_infos[i].order++;
      }
    }
  }

  int SplitOp(int stmt_idx) {
    auto pair = FindStmt(stmt_idx);
    int op_idx = pair.first;
    int inner_idx = pair.second;
    ICHECK(op_idx != -1);
    ICHECK(inner_idx != -1);
    OpInfo half0;
    OpInfo half1;
    // The order to do sync
    int sync_order = op_infos[op_idx].order + 1;
    UpdateOrder(sync_order);

    half0.group_size = inner_idx + 1;
    half0.order = op_infos[op_idx].order;
    half0.stage = op_infos[op_idx].stage;
    for (int i = 0; i <= inner_idx; i++) {
      half0.group.push_back(op_infos[op_idx].group[i]);
    }
    half1.group_size = op_infos[op_idx].group_size - inner_idx - 1;
    half1.order = op_infos[op_idx].order + 2;
    half1.stage = op_infos[op_idx].stage;
    for (int i = inner_idx + 1; i < op_infos[op_idx].group_size; i++) {
      half1.group.push_back(op_infos[op_idx].group[i]);
    }
    op_infos.erase(op_infos.begin() + op_idx);
    if (half0.group_size > 0) {
      op_infos.insert(op_infos.begin() + op_idx, half0);
    }
    if (half1.group_size > 0) {
      UpdateOrder(half1.order);
      op_infos.insert(op_infos.begin() + op_idx + 1, half1);
    }
    return sync_order;
  }

  void PrintPipelineInfo() {
    std::cout << "Print op_infos:" << std::endl;
    for (size_t i = 0; i < op_infos.size(); i++) {
440
441
      std::cout << i << " " << op_infos[i].group_size << " "
                << op_infos[i].order << " " << op_infos[i].stage << std::endl;
442
443
444
445
446
447
    }
    std::cout << "End of print" << std::endl;
  }
};

class GroupOpRewriter : public StmtExprMutator {
448
public:
449
450
  GroupOpRewriter(PipelineInfo pipeline_info) : pipeline_info_(pipeline_info) {}

451
452
private:
  Stmt VisitStmt_(const ForNode *op) final {
453
454
455
456
457
458
459
460
461
    Map<String, ObjectRef> annotations;
    annotations.Set(String("stmt_group"), Integer(1));
    auto original_node = (op->body).as<SeqStmtNode>();
    if (!original_node) {
      return GetRef<For>(op);
    }
    Array<Stmt> new_body;
    int cur_id = 0;
    for (int i = 0; i < static_cast<int>(pipeline_info_.op_infos.size()); i++) {
462
463
      if (pipeline_info_.op_infos[i].group_size == 0)
        continue;
464
      Array<Stmt> block_stmt;
465
466
      for (int j = 0;
           j < static_cast<int>(pipeline_info_.op_infos[i].group_size); j++) {
467
        // ICHECK(group_info_[i][j].as<IntImmNode>());
468
469
        // int index =
        // static_cast<int>(group_info_[i][j].as<IntImmNode>()->value);
470
471
472
473
474
475
        ICHECK(original_node->seq[cur_id].as<BlockNode>());
        auto block = original_node->seq[cur_id].as<BlockNode>();
        // TODO: handle nested seqstmt
        block_stmt.push_back(block->body);
        cur_id++;
      }
476
477
478
479
      new_body.push_back(MakeGroupBlock(block_stmt.size() == 1
                                            ? block_stmt[0]
                                            : SeqStmt(std::move(block_stmt)),
                                        annotations));
480
481
482
483
484
485
486
    }
    Array<Integer> order_anno;
    Array<Integer> stage_anno;
    for (auto op_info : pipeline_info_.op_infos) {
      order_anno.push_back(Integer(op_info.order));
      stage_anno.push_back(Integer(op_info.stage));
    }
487
    Map<String, Any> for_annotations = op->annotations;
488
489
490
    for_annotations.erase("tl_pipeline_group");
    for_annotations.Set("software_pipeline_order", order_anno);
    for_annotations.Set("software_pipeline_stage", stage_anno);
491
492
493
494
    For new_for =
        For(op->loop_var, op->min, op->extent, op->kind,
            new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body)),
            op->thread_binding, for_annotations);
495
496
497
498
499
500
    return new_for;
  }

  PipelineInfo pipeline_info_;
};
class WSCodeEmitter : public StmtMutator {
501
public:
502
  WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
503
                Map<Var, Buffer> buffer_data_to_buffer,
504
505
                const WarpSpecializedRoleMarker &marker,
                bool mbarrier_only = false)
506
      : is_emitting_producer_(is_emitting_producer),
507
        buffer_data_to_buffer_(buffer_data_to_buffer), marker_(marker),
508
        thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only) {}
509

510
511
private:
  /*!
   * \brief Keep, recurse into, or drop `op` depending on its role.
   *
   *  - kBoth statements are recursed into with the default mutator. In
   *    mbarrier-only mode the same applies to every non-producer statement.
   *  - Otherwise a statement is returned unchanged when its role matches the
   *    side being emitted, and replaced by `Evaluate(0)` when it does not.
   */
  template <typename NodeType> Stmt FilterByRole(const NodeType *op) {
    const Role role = marker_.GetRole(op);
    const bool recurse =
        (mbarrier_only_ && role != Role::kProducer) || role == Role::kBoth;
    if (recurse) {
      return StmtMutator::VisitStmt_(op);
    }
    if ((role == Role::kProducer) == is_emitting_producer_) {
      return GetRef<Stmt>(op);
    }
    return Evaluate(0);
  }

  // TODO: only need to add block for ops in the loop
527
  Stmt VisitStmt_(const SeqStmtNode *op) final {
528

529
530
531
532
533
534
535
    bool has_producer = false;
    for (auto stmt : op->seq) {
      if (marker_.GetRole(stmt) == Role::kProducer) {
        has_producer = true;
        break;
      }
    }
536
537
538
539
    bool need_producer_sync =
        has_producer && marker_.GetRole(op) == Role::kBoth;
    if (!need_producer_sync)
      return FilterByRole(op);
540

541
542
    auto seq_transformed =
        op->seq.Map([&](Stmt stmt) { return VisitStmt(stmt); });
543
544

    auto map = ExtractSyncPattern(op->seq);
545
546
547
548
549
550
551
552
553
554
555
556
557
558
    /*
      std::cout << "Print ExtractSyncPattern" << std::endl;
      for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
        std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " "
        << map.release_after[i] << std::endl;
      }
      std::cout << "Print sync pattern" << std::endl;
      for (auto pattern : map.patterns) {
        std::cout << pattern.release_idx << " " << pattern.acquire_idx <<
        std::endl;
      }
      std::cout << "End of ExtractSyncPattern" << std::endl;
      pipeline_info_.PrintPipelineInfo();
    */
559
560
561
562
    Array<Stmt> new_body;
    Map<String, ObjectRef> annotations;
    annotations.Set(String("stmt_group"), Integer(1));

563
    if (is_emitting_producer_) { // producer case
564
565
566
      ProducerTraitsCollector collector;
      for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
        Array<Stmt> block_stmt = {};
567
568
569
570
571
572
573
574
575
576
577
        if (!mbarrier_only_) {
          if (marker_.GetRole(op->seq[i]) == Role::kConsumer)
            continue;
          if (marker_.GetRole(op->seq[i]) == Role::kBoth) {
            block_stmt.push_back(seq_transformed[i]);
            new_body.push_back(MakeGroupBlock(
                block_stmt.size() == 1 ? block_stmt[0]
                                       : SeqStmt(std::move(block_stmt)),
                annotations));
            continue;
          }
578
        }
579

580
        for (int pattern_idx : map.acquire[i]) {
581
          PrimExpr acquire_barrier_id =
582
583
              stage_ + num_barriers_ + num_stages_ * pattern_idx;
          PrimExpr parity = map.is_loop_dependency(pattern_idx)
584
585
                                ? bitwise_xor(parity_, 1)
                                : parity_;
586
587
          block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
        }
588
589
590
591
592
593
594
595
        ICHECK(map.release[i].size() > 0);
        for (size_t j = 0; j < map.release[i].size(); j++) {
          int pattern_idx = map.release[i][j];
          PrimExpr release_barrier_id =
              stage_ + num_barriers_ + num_stages_ * pattern_idx;
          auto stmt =
              MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id);
          collector.Collect(stmt);
596
          block_stmt.push_back(stmt);
597
598
          if (collector.HasSimtCopy() > 0) {
            block_stmt.push_back(makeCpAsyncBarrier(release_barrier_id));
599
          }
600
601
602
603
604
605
606
607
608
609
610
611
          if (map.release_after[i][j]) {
            block_stmt.push_back(makeArriveBarrier(release_barrier_id));
            for (int s = 0; s < num_stages_; s++) {
              released_barrier_.insert(s + num_barriers_ +
                                       num_stages_ * pattern_idx);
            }
          }
          collector.Clear();
          new_body.push_back(MakeGroupBlock(
              block_stmt.size() == 1 ? block_stmt[0]
                                     : SeqStmt(std::move(block_stmt)),
              annotations));
612
613
        }
      }
614
    } else { // consumer case
615
616
      for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
        Array<Stmt> block_stmt = {};
617
618
        if (marker_.GetRole(op->seq[i]) == Role::kProducer)
          continue;
619
        for (int pattern_idx : map.acquire[i]) {
620
          PrimExpr acquire_barrier_id =
621
622
              stage_ + num_barriers_ + num_stages_ * pattern_idx;
          PrimExpr parity = map.is_loop_dependency(pattern_idx)
623
624
                                ? bitwise_xor(parity_, 1)
                                : parity_;
625
626
627
          block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
        }
        block_stmt.push_back(seq_transformed[i]);
628
629
630
631
632
633
634
635
636
637
        for (size_t j = 0; j < map.release[i].size(); j++) {
          if (map.release_after[i][j]) {
            int pattern_idx = map.release[i][j];
            PrimExpr release_barrier_id =
                stage_ + num_barriers_ + num_stages_ * pattern_idx;
            block_stmt.push_back(makeArriveBarrier(release_barrier_id));
            for (int s = 0; s < num_stages_; s++) {
              released_barrier_.insert(s + num_barriers_ +
                                       num_stages_ * pattern_idx);
            }
638
639
          }
        }
640
641
642
643
        new_body.push_back(MakeGroupBlock(block_stmt.size() == 1
                                              ? block_stmt[0]
                                              : SeqStmt(std::move(block_stmt)),
                                          annotations));
644
645
646
647
      }
      // Filter out the producer stmts
      int cur_id = 0;
      PipelineInfo new_pipeline_info;
648
649
      for (int i = 0; i < static_cast<int>(pipeline_info_.op_infos.size());
           i++) {
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
        auto op_info = pipeline_info_.op_infos[i];
        bool is_producer = false;
        for (int j = 0; j < op_info.group_size; j++) {
          if (marker_.GetRole(op->seq[cur_id]) == Role::kProducer) {
            is_producer = true;
          }
          cur_id++;
        }
        if (is_producer) {
          ICHECK(op_info.group_size == 1);
        } else {
          new_pipeline_info.op_infos.push_back(op_info);
        }
      }
      pipeline_info_ = new_pipeline_info;
    }

    num_barriers_ += map.patterns.size() * num_stages_;

    ICHECK(new_body.size() > 0);
    return new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body));
  }

673
  /*!
   * \brief Emit one side of a loop, tracking pipeline stage and parity.
   *
   * While the loop is visited:
   *  - the loop var/extent is pushed on loop_stack_ and a linear index over
   *    all enclosing loops is formed,
   *  - stage_  = linear_index mod num_stages,
   *  - parity_ = (outer parity * extent + linear_index / num_stages) mod 2,
   *  - pipeline_info_ is replaced by the info read from the
   *    "tl_pipeline_group" / "tl_pipeline_order" / "tl_pipeline_stage"
   *    annotations.
   * All of this state is restored before returning. On the consumer side a
   * loop with a non-empty group annotation is regrouped by GroupOpRewriter;
   * otherwise the pipeline annotations are erased from the returned loop.
   */
  Stmt VisitStmt_(const ForNode *op) final {
    int num_stages = 1;
    auto num_stages_anno = op->annotations.Get("num_stages");
    if (num_stages_anno) {
      ICHECK(num_stages_anno->as<IntImmNode>());
      num_stages = static_cast<int>(num_stages_anno->as<IntImmNode>()->value);
      ICHECK(num_stages_ == 1) << "Nested pipeline not supported.";
    }
    loop_stack_.emplace_back(op->loop_var, op->extent);

    Array<Array<Integer>> group_info_array;
    Array<Integer> order_info_array;
    Array<Integer> stage_info_array;

    auto group_anno = op->annotations.Get("tl_pipeline_group");
    if (group_anno) {
      group_info_array = Downcast<Array<Array<Integer>>>(group_anno.value());
    }
    auto order_anno = op->annotations.Get("tl_pipeline_order");
    if (order_anno) {
      order_info_array = Downcast<Array<Integer>>(order_anno.value());
    }
    auto stage_anno = op->annotations.Get("tl_pipeline_stage");
    if (stage_anno) {
      stage_info_array = Downcast<Array<Integer>>(stage_anno.value());
    }

    PipelineInfo pipeline_info(group_info_array, order_info_array,
                               stage_info_array);
    if (pipeline_info.op_infos.size() > 0) {
      ICHECK(pipeline_info_.op_infos.size() == 0)
          << "Nested pipeline not supported.";
    }

    // Save the enclosing state; it is restored after the loop body is done.
    PrimExpr parity_before = std::move(parity_);
    PrimExpr stage_before = std::move(stage_);
    int num_stages_before = num_stages_;
    PipelineInfo pipeline_info_before = pipeline_info_;

    num_stages_ = num_stages;
    pipeline_info_ = pipeline_info;
    PrimExpr linear_index = loop_stack_[0].first;
    for (size_t i = 1; i < loop_stack_.size(); ++i) {
      linear_index =
          linear_index * loop_stack_[i].second + loop_stack_[i].first;
    }
    stage_ = FloorMod(linear_index, num_stages);
    parity_ = FloorMod(
        parity_before * op->extent + FloorDiv(linear_index, num_stages), 2);

    auto result = FilterByRole(op);

    Stmt grouped_for_node;
    if (result.as<ForNode>() && group_anno && group_info_array.size() > 0 &&
        !is_emitting_producer_) {
      GroupOpRewriter group_op_rewriter(pipeline_info_);
      auto for_node = Downcast<For>(result);
      grouped_for_node = group_op_rewriter(for_node);
    }

    parity_ = std::move(parity_before);
    stage_ = std::move(stage_before);
    num_stages_ = num_stages_before;
    pipeline_info_ = pipeline_info_before;
    // Nothing below uses the loop stack, so pop once here instead of once on
    // every return path.
    loop_stack_.pop_back();

    if (!result.as<ForNode>()) {
      return result;
    }

    // remove pipeline annotation
    auto for_node = Downcast<For>(result);
    for_node.CopyOnWrite()->annotations.erase("num_stages");
    if (is_emitting_producer_ || group_info_array.size() == 0) {
      for_node.CopyOnWrite()->annotations.erase("tl_pipeline_order");
      for_node.CopyOnWrite()->annotations.erase("tl_pipeline_stage");
    }
    if (is_emitting_producer_ || !group_anno ||
        group_info_array.size() == 0) {
      return for_node;
    }
    // This is exactly the case in which grouped_for_node was built above.
    return grouped_for_node;
  }

759
760
761
762
763
764
765
  // These statement kinds need no special handling: each is kept, recursed
  // into, or dropped purely by its role (see FilterByRole).
  Stmt VisitStmt_(const IfThenElseNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const EvaluateNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const AttrStmtNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const BufferStoreNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const LetStmtNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const AssertStmtNode *op) final { return FilterByRole(op); }
  /*! \brief Visiting a Block is treated as a fatal error by this emitter. */
  Stmt VisitStmt_(const BlockNode *op) final {
    ICHECK(0) << "WSCodeEmitter does not expect to visit a BlockNode";
    return Stmt();
  }
769
  /*!
   * \brief Visiting a BlockRealize is treated as a fatal error by this
   *        emitter.
   */
  Stmt VisitStmt_(const BlockRealizeNode *op) final {
    ICHECK(0) << "WSCodeEmitter does not expect to visit a BlockRealizeNode";
    return Stmt();
  }

  // One release/acquire pair between two statements of a sequence. Both
  // fields are indices into that sequence: the statement at release_idx
  // releases (arrives on) the barrier, the statement at acquire_idx waits on
  // it.
  struct SyncPattern {
    int release_idx, acquire_idx;
  };

  struct SyncPatternMap {
779
780
781
    std::vector<std::vector<int>> acquire;
    std::vector<std::vector<int>> release;
    std::vector<std::vector<bool>> release_after;
782
    std::vector<SyncPattern> patterns;
783
784
785
786
787
788
789
790
791
792

    void resize(size_t n) {
      acquire.resize(n);
      release.resize(n);
      release_after.resize(n);
    }

    bool is_loop_dependency(int pattern_idx) {
      return patterns[pattern_idx].release_idx >
             patterns[pattern_idx].acquire_idx;
793
794
795
    }
  };

796
797
798
  /*!
   * \brief Build the raw list of release/acquire pairs of a sequence.
   *
   * Two statements conflict when exactly one of them is a producer and one
   * writes a buffer the other reads. For each statement i:
   *  - the first later conflicting statement j yields the pair (i, j);
   *  - when inside a loop (parity_ is not the constant zero), the first
   *    earlier conflicting statement j also yields (i, j), which is a
   *    loop-carried pair.
   * \param seq_stmt The statements of the sequence.
   * \param is_producer is_producer[i] is true when statement i is a producer.
   * \return The pairs in discovery order (forward pairs first).
   */
  std::vector<SyncPattern>
  CreateBaseSyncPairs(Array<Stmt> seq_stmt,
                      const std::vector<bool> &is_producer) {
    const int n = seq_stmt.size();
    std::vector<std::set<const BufferNode *>> reads, writes;
    reads.reserve(n);
    writes.reserve(n);

    // Turn a list of accessed regions into the set of buffers behind them,
    // preferring the buffer registered for the data var when there is one.
    auto collect_buffers = [this](const auto &regions) {
      std::set<const BufferNode *> buffers;
      for (auto region : regions) {
        auto var = region->buffer->data;
        if (buffer_data_to_buffer_.count(var)) {
          buffers.insert(buffer_data_to_buffer_[var].get());
        } else {
          buffers.insert(region->buffer.get());
        }
      }
      return buffers;
    };

    for (int i = 0; i < n; i++) {
      Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
                  /*name_hint=*/"",
                  /*body*/ seq_stmt[i]);
      auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
      reads.push_back(collect_buffers(access[0]));
      writes.push_back(collect_buffers(access[1]));
    }

    auto intersect_fn = [](const std::set<const BufferNode *> &lhs,
                           const std::set<const BufferNode *> &rhs) {
      for (auto ptr : lhs)
        if (rhs.count(ptr))
          return true;
      return false;
    };

    // Statements i and j need a barrier between them.
    auto conflicts = [&](int i, int j) {
      return is_producer[i] != is_producer[j] &&
             (intersect_fn(writes[i], reads[j]) ||
              intersect_fn(reads[i], writes[j]));
    };

    std::vector<SyncPattern> sync_patterns;
    // producer_release consumer_acquire,
    // inject before the first consumer stmt for each producer
    for (int i = 0; i < n; i++) {
      for (int j = i + 1; j < n; j++) {
        if (conflicts(i, j)) {
          sync_patterns.push_back({i, j});
          break;
        }
      }
    }

    // consumer_release producer_acquire
    // valid when is_loop is true
    // inject before the earliest producer stmt for each consumer
    bool in_loop = !is_zero(parity_);
    if (in_loop) {
      for (int i = 0; i < n; i++) {
        for (int j = 0; j < i; j++) {
          if (conflicts(i, j)) {
            sync_patterns.push_back({i, j});
            break;
          }
        }
      }
    }

    return sync_patterns;
  }

871
872
873
  static std::vector<SyncPattern>
  RemoveUnusedSyncPatterns(const std::vector<SyncPattern> &sync_patterns,
                           const std::vector<bool> &is_producer) {
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
    /*
      Simplify multiple release-acquire pairs into one
      ------------------
        Produce(A)
        Produce(B)
        Consume(A, B)
      ------------------
      [(0, 2), (1, 2), (2, 0)] -> [(1, 2), (2, 0)]

      Or
      ------------------
        Produce(A, B)
        Consume(A)
        Consume(B)
      ------------------
      [(0, 1), (1, 0), (2, 0)] -> [(0, 1), (2, 0)]
    */
    int M = sync_patterns.size();
    std::vector<bool> removed(M, false);
    for (int i = 0; i < M; i++) {
      for (int j = 0; j < M; j++) {
        if (is_producer[sync_patterns[i].acquire_idx] ==
                is_producer[sync_patterns[j].acquire_idx] &&
            sync_patterns[i].acquire_idx >= sync_patterns[j].acquire_idx &&
            sync_patterns[i].release_idx < sync_patterns[j].release_idx)
          removed[i] = true;
      }
    }

    std::vector<SyncPattern> sync_pattern_cleaned;
    sync_pattern_cleaned.reserve(M);
    for (int i = 0; i < M; i++)
906
907
      if (!removed[i])
        sync_pattern_cleaned.push_back(sync_patterns[i]);
908
909
910
911
912
913
914
915
916
917
918
919
920

    return sync_pattern_cleaned;
  }

  /*!
   * \brief Build the per-statement acquire/release tables for a statement
   *        sequence.
   * \param seq_stmt The statements of the sequence, in program order.
   * \return A SyncPatternMap holding the cleaned sync patterns together with,
   *         for every statement, the pattern indices it acquires and releases
   *         and a flag per release entry (`release_after`).
   */
  SyncPatternMap ExtractSyncPattern(Array<Stmt> seq_stmt) {
    const size_t stmt_count = seq_stmt.size();

    // A statement counts as a producer only when its role is exactly
    // Role::kProducer.
    std::vector<bool> is_producer;
    is_producer.reserve(stmt_count);
    for (auto stmt : seq_stmt) {
      is_producer.push_back(marker_.GetRole(stmt) == Role::kProducer);
    }

    // Base pairs first, then prune the redundant ones.
    auto sync_patterns = RemoveUnusedSyncPatterns(
        CreateBaseSyncPairs(seq_stmt, is_producer), is_producer);

    SyncPatternMap map;
    map.resize(stmt_count);
    map.patterns = sync_patterns;

    // Register every pattern on its acquiring and releasing statement.
    // Entries added here are flagged `true` in release_after.
    for (size_t pattern_id = 0; pattern_id < sync_patterns.size();
         pattern_id++) {
      const auto &pattern = sync_patterns[pattern_id];
      map.acquire[pattern.acquire_idx].push_back(pattern_id);
      map.release[pattern.release_idx].push_back(pattern_id);
      map.release_after[pattern.release_idx].push_back(true);
    }

    // Walk the statements backwards. A statement that already releases
    // something appends its patterns to the pending list of its own side;
    // a statement without any release inherits every pattern currently
    // pending for its side, flagged `false` in release_after.
    std::vector<int> pending_consumer, pending_producer;
    for (int i = static_cast<int>(stmt_count) - 1; i >= 0; i--) {
      std::vector<int> &pending =
          is_producer[i] ? pending_producer : pending_consumer;
      if (map.release[i].size() == 0) {
        for (auto pattern_idx : pending) {
          map.release[i].push_back(pattern_idx);
          map.release_after[i].push_back(false);
        }
      } else {
        for (auto pattern_idx : map.release[i]) {
          pending.push_back(pattern_idx);
        }
      }
    }
    return map;
  }

  const bool is_emitting_producer_;
  Map<Var, Buffer> buffer_data_to_buffer_;
  std::unordered_set<int> released_barrier_;
974
  const WarpSpecializedRoleMarker &marker_;
975
976
977
978
979

  int num_barriers_ = 0;
  PrimExpr parity_ = 0;
  PrimExpr stage_ = 0;
  int num_stages_ = 1;
980
  std::vector<std::pair<Var, PrimExpr>> loop_stack_;
981
  Var thread_var_;
982
  bool mbarrier_only_ = false;
983
984
985
986
  PipelineInfo pipeline_info_;
  friend class WarpSpecializedRewriter;
};

987
988
989
990
991
class SetMaxNRegCollector : public StmtExprVisitor {
public:
  static Array<IntImm> Collect(const PrimFunc &f) {
    SetMaxNRegCollector collector;
    collector(f->body);
992
993
994
995
    return collector.has_no_set_max_nreg_
               ? Array<IntImm>({IntImm(DataType::Int(32), -1),
                                IntImm(DataType::Int(32), -1)})
               : collector.nreg_;
996
997
998
999
1000
  }

private:
  void VisitStmt_(const EvaluateNode *op) final {
    if (const CallNode *call = op->value.as<CallNode>()) {
1001
      if (call->op.same_as(set_max_nreg())) {
1002
1003
1004
1005
1006
1007
1008
1009
1010
        int reg_hint = call->args[0].as<IntImmNode>()->value;
        int is_inc = call->args[1].as<IntImmNode>()->value;
        ICHECK(reg_hint <= 240 && reg_hint >= 24)
            << "Invalid reg hint: " << reg_hint;
        ICHECK(is_inc == 0 || is_inc == 1) << "Invalid is_inc: " << is_inc;

        // producer should decrease register hint while consumer should increase
        // register hint
        nreg_.Set(is_inc, IntImm(DataType::Int(32), reg_hint));
1011
      } else if (call->op.same_as(no_set_max_nreg())) {
1012
        has_no_set_max_nreg_ = true;
1013
1014
1015
1016
1017
1018
1019
      }
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  Array<IntImm> nreg_{IntImm(DataType::Int(32), 0),
                      IntImm(DataType::Int(32), 0)};
1020
  bool has_no_set_max_nreg_ = false;
1021
1022
};

1023
class WarpSpecializedRewriter : public StmtExprMutator {
1024
public:
1025
1026
1027
  WarpSpecializedRewriter(bool disable_warp_specialized)
      : disable_warp_specialized_(disable_warp_specialized) {}
  static PrimFunc Substitute(PrimFunc f, bool disable_warp_specialized) {
1028
1029
1030
    // Check if function only uses threadIdx.x before proceeding
    if (!ThreadTagChecker::HasOnlyThreadIdxX(f)) {
      LOG(WARNING) << "WarpSpecialize will be disabled because the program "
1031
                      "uses thread tags other than threadIdx.x."
1032
1033
1034
1035
1036
1037
                   << "If you want to use warp specialization, please refactor "
                      "your program to use threadIdx.x only";
      // Return original function unchanged if other thread tags are found
      return f;
    }

1038
    auto T = WarpSpecializedRewriter(disable_warp_specialized);
1039
    T.nreg_ = SetMaxNRegCollector::Collect(f);
1040
    T.buffer_lca_ = DetectBufferAccessLCA(f);
1041
1042
    for (auto [buffer, _] : T.buffer_lca_)
      T.buffer_data_to_buffer_.Set(buffer->data, buffer);
1043
1044
1045
1046
    f.CopyOnWrite()->body = T(f->body);
    return f;
  }

1047
1048
private:
  Stmt VisitStmt_(const AttrStmtNode *op) final {
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
    if (op->attr_key == tir::attr::thread_extent &&
        Downcast<IterVar>(op->node)->thread_tag == "threadIdx.x") {
      thread_iv_ = Downcast<IterVar>(op->node);
      need_update_thread_extent_ = false;
      AttrStmt attr_stmt = Downcast<AttrStmt>(StmtExprMutator::VisitStmt_(op));
      if (need_update_thread_extent_) {
        thread_iv_.CopyOnWrite()->dom = {0, updated_thread_extent_.value()};
        attr_stmt.CopyOnWrite()->node = thread_iv_;
        attr_stmt.CopyOnWrite()->value = updated_thread_extent_.value();
      }
      thread_iv_ = {};
      return attr_stmt;
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

1066
1067
  Stmt VisitStmt_(const EvaluateNode *op) final {
    if (const CallNode *call = op->value.as<CallNode>()) {
1068
1069
      if (call->op.same_as(set_max_nreg()) ||
          call->op.same_as(no_set_max_nreg())) {
1070
1071
1072
1073
1074
1075
        return Evaluate(0);
      }
    }
    return StmtExprMutator::VisitStmt_(op);
  }

1076
1077
1078
1079
  // If users define a thread binding, we will replace the thread binding with
  // threadIdx.x We require the thread binding is threadIdx.x, and the extent is
  // the same as the thread extent
  Stmt VisitStmt_(const ForNode *op) final {
1080
1081
1082
1083
1084
1085
1086
    ICHECK(thread_iv_.defined());
    For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
    if (for_node->kind == ForKind::kThreadBinding) {
      ICHECK(for_node->thread_binding.defined());
      String thread_tag = for_node->thread_binding.value()->thread_tag;
      ICHECK(thread_tag == "threadIdx.x") << "Only support threadIdx.x";
      Var thread_iv = Downcast<Var>(for_node->loop_var);
1087
1088
      Stmt new_body =
          ThreadIdxRewriter::Rewrite(for_node->body, thread_iv, thread_iv_);
1089
1090
1091
1092
1093
      return new_body;
    }
    return for_node;
  }

1094
1095
1096
  Stmt VisitStmt_(const BlockRealizeNode *op) final {
    BlockRealize block_realize =
        Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
1097
1098
1099
1100
1101
1102
    if (!thread_iv_.defined()) {
      return block_realize;
    }

    Block block = block_realize->block;
    WarpSpecializedRoleMarker marker(buffer_data_to_buffer_);
1103
    marker.Prepare(block);
1104
1105
1106
1107
1108
1109
    marker(block);
    if (!marker.HasProducer()) {
      // Cannot detect any producer here, directly return.
      return block_realize;
    }

1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
    if (disable_warp_specialized_) {
      WSCodeEmitter mbarrier_emitter(true, thread_iv_, buffer_data_to_buffer_,
                                     marker, true);
      auto code = mbarrier_emitter(block->body);
      int num_barriers = mbarrier_emitter.num_barriers_;
      Array<PrimExpr> barrier_num_threads;
      barrier_num_threads.reserve(num_barriers);
      PrimExpr arrive_thread_count = thread_iv_->dom->extent;
      for (int i = 0; i < num_barriers; i++) {
        barrier_num_threads.push_back(arrive_thread_count);
      }
      Stmt init_barrier = Evaluate(Call(
1122
          DataType::Handle(), create_list_of_mbarrier(), barrier_num_threads));
1123
1124
1125
1126
      block.CopyOnWrite()->body = SeqStmt({init_barrier, code});
      block_realize.CopyOnWrite()->block = block;
      return block_realize;
    }
1127
1128
1129
1130
1131
1132
1133
    WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker);
    WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker);
    Stmt producer_code = producer(block->body);
    Stmt consumer_code = consumer(block->body);
    PrimExpr consumer_thread_extent = thread_iv_->dom->extent;
    PrimExpr producer_thread_extent = thread_iv_->dom->extent;
    // Need one warp-group for bulk-copy only case
1134
1135
    if (!marker.HasSimtCopy())
      producer_thread_extent = 128;
1136
1137

    // TODO: estimate the correct reg usage.
1138
1139
1140
    int dec_reg = nreg_[0].as<IntImmNode>()->value;
    int inc_reg = nreg_[1].as<IntImmNode>()->value;

1141
1142
    auto inc_reg_stmt = Evaluate(0);
    auto dec_reg_stmt = Evaluate(0);
1143
    if (dec_reg >= 0 && inc_reg >= 0 && !marker.HasSimtCopy()) {
1144
      inc_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
1145
                                   {inc_reg == 0 ? 240 : inc_reg, 1}));
1146
      dec_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
1147
1148
                                   {dec_reg == 0 ? 24 : dec_reg, 0}));
    }
1149
1150
1151
1152

    producer_code = SeqStmt({dec_reg_stmt, producer_code});
    consumer_code = SeqStmt({inc_reg_stmt, consumer_code});

1153
1154
1155
    producer_code =
        ThreadIdxRewriter::Rewrite(producer_code, thread_iv_->var,
                                   thread_iv_->var - consumer_thread_extent);
1156
1157
1158
1159
1160
1161
1162
1163
1164
    updated_thread_extent_ = consumer_thread_extent + producer_thread_extent;
    need_update_thread_extent_ = true;

    ICHECK(producer.num_barriers_ == consumer.num_barriers_)
        << producer.num_barriers_ << " " << consumer.num_barriers_;
    int num_barriers = consumer.num_barriers_;
    Array<PrimExpr> barrier_num_threads;
    barrier_num_threads.reserve(num_barriers);
    for (int i = 0; i < num_barriers; i++) {
1165
1166
1167
      PrimExpr arrive_thread_count = producer.released_barrier_.count(i)
                                         ? producer_thread_extent
                                         : consumer_thread_extent;
1168
1169
1170
      barrier_num_threads.push_back(arrive_thread_count);
    }

1171
    Stmt init_barrier = Evaluate(Call(
1172
        DataType::Handle(), create_list_of_mbarrier(), barrier_num_threads));
1173
1174
    Stmt body = IfThenElse(GE(thread_iv_->var, consumer_thread_extent),
                           producer_code, consumer_code);
1175
    // Add an attr here to handle the partial thread count in ThreadSync pass.
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
    Array<IntImm> ws_partition = {Downcast<IntImm>(producer_thread_extent),
                                  Downcast<IntImm>(consumer_thread_extent)};
    body = AttrStmt(ws_partition, "kWarpSpecializationScope", 0, body);

    block.CopyOnWrite()->body = SeqStmt({init_barrier, body});
    block_realize.CopyOnWrite()->block = block;
    return block_realize;
  }

  WarpSpecializedRewriter() = default;

  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Optional<Stmt>> buffer_lca_;
  Map<Buffer, Buffer> buffer_remap_;
  IterVar thread_iv_;
  Optional<PrimExpr> updated_thread_extent_;
  bool need_update_thread_extent_ = false;
1193
  bool disable_warp_specialized_ = false;
1194
  Array<IntImm> nreg_;
1195
1196
};

1197
1198
1199
1200
1201
class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
public:
  static bool Detect(Stmt stmt, bool skip_thread_partition = false) {
    WarpSpecializedDetector detector;
    detector.VisitStmt(stmt);
1202
1203
    return detector.has_warp_specialization_ ||
           (detector.has_tma_op_ && detector.has_mbarrier_op_);
1204
1205
1206
1207
1208
  }

  WarpSpecializedDetector() {
    has_tma_op_ = false;
    has_mbarrier_op_ = false;
1209
    has_warp_specialization_ = false;
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
  }

private:
  void VisitStmt_(const EvaluateNode *op) final {
    if (const CallNode *call = op->value.as<CallNode>()) {
      if (call->op.same_as(create_list_of_mbarrier()) ||
          call->op.same_as(mbarrier_wait_parity()) ||
          call->op.same_as(builtin::ptx_arrive_barrier()) ||
          call->op.same_as(builtin::ptx_cp_async_barrier())) {
        has_mbarrier_op_ = true;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  void VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
        op->op.same_as(set_max_nreg())) {
      has_tma_op_ = true;
    }
    IRVisitorWithAnalyzer::VisitExpr_(op);
  }

1233
  void VisitStmt_(const AttrStmtNode *op) final {
1234
1235
1236
1237
    if (op->attr_key == "warp_specialize" &&
        op->value.as<IntImmNode>()->value == 1) {
      has_warp_specialization_ = true;
    }
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        ICHECK(iv->dom->extent.as<IntImmNode>());
        thread_var_ = iv;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

1248
  bool has_tma_op_{false};
1249
  IterVar thread_var_;
1250
  bool has_mbarrier_op_{false};
1251
  bool has_warp_specialization_{false};
1252
1253
};

1254
1255
1256
1257
using namespace tir::transform;

/*!
 * \brief Create the "tl.WarpSpecialized" PrimFunc pass.
 *
 * Functions that WarpSpecializedDetector reports as already specialized are
 * returned untouched; all others go through
 * WarpSpecializedRewriter::Substitute, honoring the kDisableWarpSpecialized
 * pass-context option (default: false).
 */
tvm::transform::Pass WarpSpecialized() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    const bool disable_ws =
        ctx->GetConfig<Bool>(kDisableWarpSpecialized, Bool(false)).value();
    const bool already_specialized = WarpSpecializedDetector::Detect(f->body);

    // Nothing to do for hand-written / previously specialized kernels.
    if (already_specialized) {
      return f;
    }
    return WarpSpecializedRewriter::Substitute(f, disable_ws);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
}

1270
1271
1272
1273
// Registers the WarpSpecialized pass factory through refl::GlobalDef() under
// the global name "tl.transform.WarpSpecialized".
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.WarpSpecialized", WarpSpecialized);
});
1274

1275
1276
} // namespace tl
} // namespace tvm