warp_specialized_rewriter.cc 41.2 KB
Newer Older
1
/*!
2
 * \file warp_specialized_rewriter.cc
3
4
5
 * \brief Warp specialized Pipeline for cuda GPU (sm90+)
 */

6
#include "arith/ir_visitor_with_analyzer.h"
7
#include "tir/analysis/var_use_def_analysis.h"
8
#include <tvm/ffi/reflection/registry.h>
9
10
11
12
13
14
15
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../op/builtin.h"
16
#include "./common/collector.h"
17
18
19
20
21

namespace tvm {
namespace tl {

using namespace tir;
22
using arith::IRVisitorWithAnalyzer;
23
24
25

// Role assigned to a statement when the kernel is split by warp
// specialization: it belongs to the consumer side, the producer side, or both.
enum class Role { kConsumer, kProducer, kBoth };

26
27
28
29
30
class TMAFinder : public StmtExprVisitor {
public:
  void clear() { has_tma_load_ = false; }

  void VisitExpr_(const CallNode *call) final {
31
    if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
      has_tma_load_ = true;
    }
  }

  bool has_tma_load_ = false;
};

/*!
 * \brief Collects the buffers referenced by control expressions (if
 *        conditions and loop bounds) whose bodies contain a TMA load.
 */
class ProducerUsedBufferFinder : public StmtExprVisitor {
public:
  /*!
   * \brief Visit `stmt` and return the buffers gathered so far.
   * \param stmt The statement to scan.
   * \return A copy of the set of buffers found in guarding expressions.
   */
  auto FindProducerusedBuffer(Stmt stmt) {
    VisitStmt(stmt);
    return used_in_producer_cond_;
  }

  /*!
   * \brief Record every buffer that `expr` uses.
   * \param expr The condition or loop bound to analyze.
   */
  void InsertBuffer(const PrimExpr &expr) {
    // Find the buffers that are used in the expression.
    VarUseDefAnalyzer usage(Array<Var>{});
    usage(expr);
    for (const auto &buffer : usage.buffer_use_count_) {
      used_in_producer_cond_.insert(buffer.first);
    }
  }

  void VisitStmt_(const IfThenElseNode *op) final {
    // The condition only matters when one of the branches issues a TMA load.
    TMAFinder tma_finder;
    tma_finder(op->then_case);
    if (op->else_case.defined()) {
      tma_finder(op->else_case.value());
    }
    if (tma_finder.has_tma_load_) {
      InsertBuffer(op->condition);
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const ForNode *op) final {
    // The loop bounds only matter when the loop body issues a TMA load.
    TMAFinder tma_finder;
    tma_finder(op->body);
    if (tma_finder.has_tma_load_) {
      InsertBuffer(op->min);
      InsertBuffer(op->extent);
    }
    StmtExprVisitor::VisitStmt_(op);
  }

private:
  /*! \brief Buffers used by expressions that guard a TMA load. */
  std::unordered_set<const BufferNode *> used_in_producer_cond_;
};

83
class WarpSpecializedRoleMarker : public StmtVisitor {
84
public:
85
86
87
  WarpSpecializedRoleMarker(Map<Var, Buffer> buffer_data_to_buffer)
      : buffer_data_to_buffer_(buffer_data_to_buffer) {}

88
89
90
91
92
  void Prepare(const Stmt &stmt) {
    ProducerUsedBufferFinder finder;
    used_in_producer_cond_ = finder.FindProducerusedBuffer(stmt);
  }

93
  Role GetRole(const StmtNode *stmt) const {
94
95
96
97
98
    auto it = map_.find(stmt);
    ICHECK(it != map_.end());
    return it->second;
  }

99
  Role GetRole(const Stmt &stmt) const { return GetRole(stmt.get()); }
100

101
  void VisitStmt_(const EvaluateNode *op) final {
102
103
    Role role = Role::kConsumer;
    if (auto call = op->value.as<CallNode>()) {
104
      if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
105
106
107
        role = Role::kProducer;
        has_bulk_copy_ = true;
      }
108
109
110
      if (call->op.same_as(loop_break())) {
        role = Role::kBoth;
      }
111
112
113
114
    }
    SetRole(op, role);
  }

115
116
117
  void VisitStmt_(const BufferStoreNode *op) final {
    bool is_shared_store =
        op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
118
119
120
121
    if (used_in_producer_cond_.count(op->buffer.get())) {
      SetRole(op, Role::kBoth);
      return;
    }
122
123
124
125
126
127
128
129
130
131
132
    if (!is_shared_store) {
      SetRole(op, Role::kConsumer);
      return;
    }

    // Check reads from global
    Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
                /*body*/ GetRef<Stmt>(op));
    auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
    auto reads = access[0];
    Role role = Role::kProducer;
133
134
    if (reads.empty())
      role = Role::kConsumer;
135
136
137
138
139
140
    for (auto read : reads) {
      if (read->buffer.scope() != "global") {
        role = Role::kConsumer;
        break;
      }
    }
141
142
    if (role == Role::kProducer)
      has_simt_copy_ = true;
143
144
145
    SetRole(op, role);
  }

146
  void VisitStmt_(const SeqStmtNode *op) final {
147
148
149
150
151
152
153
154
155
156
157
    StmtVisitor::VisitStmt_(op);
    auto role = GetRole(op->seq[0]);
    for (auto stmt : op->seq) {
      if (role != GetRole(stmt)) {
        role = Role::kBoth;
        break;
      }
    }
    SetRole(op, role);
  }

158
  void VisitStmt_(const IfThenElseNode *op) final {
159
160
161
162
    StmtVisitor::VisitStmt_(op);
    auto role = GetRole(op->then_case);
    if (op->else_case.defined()) {
      auto role_else = GetRole(op->else_case.value());
163
164
      if (role != role_else)
        role = Role::kBoth;
165
166
167
168
    }
    SetRole(op, role);
  }

169
  void VisitStmt_(const BlockRealizeNode *op) final {
170
171
172
173
    StmtVisitor::VisitStmt_(op);
    SetRole(op, GetRole(op->block));
  }

174
175
176
177
178
179
  void VisitStmt_(const AllocateNode *op) final {
    StmtVisitor::VisitStmt_(op);
    Role role = Role::kConsumer;
    SetRole(op, role);
  }

180
  template <class NodeType> void HandleBodyStmt(const NodeType *op) {
181
182
183
184
    StmtVisitor::VisitStmt_(op);
    SetRole(op, GetRole(op->body));
  }

185
  void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); }
186
  void VisitStmt_(const WhileNode *op) final { HandleBodyStmt(op); }
187
188
189
190
  void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); }
191
192
193
194
195

  bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; }

  bool HasSimtCopy() { return has_simt_copy_; }

196
197
private:
  void SetRole(const StmtNode *stmt, Role role) { map_[stmt] = role; }
198
  Map<Var, Buffer> buffer_data_to_buffer_;
199
  std::unordered_map<const StmtNode *, Role> map_;
200
201
  bool has_simt_copy_ = false;
  bool has_bulk_copy_ = false;
202
  std::unordered_set<const BufferNode *> used_in_producer_cond_;
203
204
205
};

/*! \brief Build a `get_mbarrier(barrier_id)` call expression of handle type. */
static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
  return Call(DataType::Handle(), get_mbarrier(), {barrier_id});
}

/*! \brief Build a statement that calls `ptx_arrive_barrier` on the barrier. */
static Stmt makeArriveBarrier(PrimExpr barrier_id) {
  auto call = Call(DataType::Handle(), builtin::ptx_arrive_barrier(),
                   {makeGetBarrier(barrier_id)});
  return Evaluate(call);
}

/*! \brief Build a statement that calls `ptx_cp_async_barrier` on the barrier. */
static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) {
  auto call = Call(DataType::Handle(), builtin::ptx_cp_async_barrier(),
                   {makeGetBarrier(barrier_id)});
  return Evaluate(call);
}

/*!
 * \brief Build a statement that calls `mbarrier_wait_parity` on the barrier
 *        with the given parity expression.
 */
static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
  auto call = Call(DataType::Handle(), mbarrier_wait_parity(),
                   {makeGetBarrier(barrier_id), parity});
  return Evaluate(call);
}

class ProducerTraitsCollector : public StmtExprVisitor {
228
public:
229
230
  ProducerTraitsCollector() { Clear(); }

231
  void Clear() { has_simt_copy = false; }
232
233
234
235
236

  void Collect(Stmt stmt) { VisitStmt(stmt); }

  bool HasSimtCopy() { return has_simt_copy; }

237
private:
238
239
240
241
242
243
244
245
246
247
248
249
  void VisitStmt_(const IfThenElseNode *op) final {
    bool old_in_if_cond = in_if_cond_;
    in_if_cond_ = true;
    VisitExpr(op->condition);
    in_if_cond_ = old_in_if_cond;

    VisitStmt(op->then_case);
    if (op->else_case.defined()) {
      VisitStmt(op->else_case.value());
    }
  }

250
  void VisitExpr_(const BufferLoadNode *op) final {
251
252
253
    if (!in_if_cond_) {
      has_simt_copy = true;
    }
254
255
256
257
    StmtExprVisitor::VisitExpr_(op);
  }

  bool has_simt_copy;
258
  bool in_if_cond_ = false;
259
260
261
262
};

// Rewrite the producer Stmt to use the correct barrier index
class MbarrierRewriter : public StmtExprMutator {
263
public:
264
265
266
267
268
269
  static Stmt Rewrite(Stmt stmt, PrimExpr barrier_id) {
    MbarrierRewriter rewriter;
    rewriter.producer_barrier_idx_ = barrier_id;
    return rewriter(stmt);
  }

270
271
private:
  PrimExpr VisitExpr_(const CallNode *op) final {
272
    auto call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
273
    if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
274
275
276
277
278
279
280
281
282
283
      Call access_ptr = Downcast<Call>(call->args[2]);
      ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
      call.CopyOnWrite()->args.Set(1, makeGetBarrier(producer_barrier_idx_));
    }
    return call;
  }
  PrimExpr producer_barrier_idx_;
};

class ThreadIdxRewriter : public StmtExprMutator {
284
public:
285
286
287
288
289
  static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced) {
    auto rewriter = ThreadIdxRewriter(thread_var, replaced);
    return rewriter(stmt);
  }

290
private:
291
292
293
  ThreadIdxRewriter(Var thread_var, PrimExpr replaced)
      : thread_var_(thread_var), replaced_(replaced) {}

294
  PrimExpr VisitExpr_(const VarNode *var) final {
295
296
297
298
299
300
301
302
303
304
305
    if (var == thread_var_.get()) {
      return replaced_;
    } else {
      return StmtExprMutator::VisitExpr_(var);
    }
  }

  Var thread_var_;
  PrimExpr replaced_;
};

306
307
308
309
310
311
/*!
 * \brief Wrap `stmt` in an anonymous Block carrying the given annotations.
 *
 * The block has no iteration variables, read/write regions, init statement,
 * allocated buffers or match buffers.
 */
Block MakeGroupBlock(const Stmt &stmt,
                     const Map<String, ObjectRef> &annotations) {
  Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
              /*body*/ stmt,
              /*init=*/{}, /*alloc_buffers=*/{}, /*match_buffers=*/{},
              /*annotations=*/annotations);
  return block;
}

/*!
 * \brief Description of one pipeline op group: the number of statements in
 *        the group, its pipeline order and stage, and the statement indices.
 *
 * The scalar members are zero-initialized so that a default-constructed
 * OpInfo never exposes indeterminate values.
 */
struct OpInfo {
  int group_size = 0;
  int order = 0;
  int stage = 0;
  std::vector<int> group;
};
struct PipelineInfo {
  std::vector<OpInfo> op_infos;

  PipelineInfo() = default;
323
324
  PipelineInfo(Array<Array<Integer>> group_info, Array<Integer> order_info,
               Array<Integer> stage_info) {
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
    int n = static_cast<int>(group_info.size());
    ICHECK(n == static_cast<int>(order_info.size()));
    ICHECK(n == static_cast<int>(stage_info.size()));
    // int cur_id = 0;
    for (int i = 0; i < n; i++) {
      OpInfo op_info;
      op_info.group_size = group_info[i].size();
      for (int j = 0; j < op_info.group_size; j++) {
        op_info.group.push_back(group_info[i][j].as<IntImmNode>()->value);
      }
      op_info.order = order_info[i].as<IntImmNode>()->value;
      op_info.stage = stage_info[i].as<IntImmNode>()->value;
      op_infos.push_back(op_info);
    }
  }

341
  PipelineInfo(const PipelineInfo &other) {
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
    for (auto op_info : other.op_infos) {
      op_infos.push_back(op_info);
    }
  }

  std::pair<int, int> FindStmt(int stmt_idx) {
    for (size_t i = 0; i < op_infos.size(); i++) {
      for (size_t j = 0; j < op_infos[i].group.size(); j++) {
        if (op_infos[i].group[j] == stmt_idx) {
          return std::make_pair(i, j);
        }
      }
    }
    return std::make_pair(-1, -1);
  }

  void UpdateOrder(int order) {
    for (int i = 0; i < static_cast<int>(op_infos.size()); i++) {
      if (op_infos[i].order >= order && op_infos[i].order > 0) {
        op_infos[i].order++;
      }
    }
  }

  int SplitOp(int stmt_idx) {
    auto pair = FindStmt(stmt_idx);
    int op_idx = pair.first;
    int inner_idx = pair.second;
    ICHECK(op_idx != -1);
    ICHECK(inner_idx != -1);
    OpInfo half0;
    OpInfo half1;
    // The order to do sync
    int sync_order = op_infos[op_idx].order + 1;
    UpdateOrder(sync_order);

    half0.group_size = inner_idx + 1;
    half0.order = op_infos[op_idx].order;
    half0.stage = op_infos[op_idx].stage;
    for (int i = 0; i <= inner_idx; i++) {
      half0.group.push_back(op_infos[op_idx].group[i]);
    }
    half1.group_size = op_infos[op_idx].group_size - inner_idx - 1;
    half1.order = op_infos[op_idx].order + 2;
    half1.stage = op_infos[op_idx].stage;
    for (int i = inner_idx + 1; i < op_infos[op_idx].group_size; i++) {
      half1.group.push_back(op_infos[op_idx].group[i]);
    }
    op_infos.erase(op_infos.begin() + op_idx);
    if (half0.group_size > 0) {
      op_infos.insert(op_infos.begin() + op_idx, half0);
    }
    if (half1.group_size > 0) {
      UpdateOrder(half1.order);
      op_infos.insert(op_infos.begin() + op_idx + 1, half1);
    }
    return sync_order;
  }

  void PrintPipelineInfo() {
    std::cout << "Print op_infos:" << std::endl;
    for (size_t i = 0; i < op_infos.size(); i++) {
404
405
      std::cout << i << " " << op_infos[i].group_size << " "
                << op_infos[i].order << " " << op_infos[i].stage << std::endl;
406
407
408
409
410
411
    }
    std::cout << "End of print" << std::endl;
  }
};

class GroupOpRewriter : public StmtExprMutator {
412
public:
413
414
  GroupOpRewriter(PipelineInfo pipeline_info) : pipeline_info_(pipeline_info) {}

415
416
private:
  Stmt VisitStmt_(const ForNode *op) final {
417
418
419
420
421
422
423
424
425
    Map<String, ObjectRef> annotations;
    annotations.Set(String("stmt_group"), Integer(1));
    auto original_node = (op->body).as<SeqStmtNode>();
    if (!original_node) {
      return GetRef<For>(op);
    }
    Array<Stmt> new_body;
    int cur_id = 0;
    for (int i = 0; i < static_cast<int>(pipeline_info_.op_infos.size()); i++) {
426
427
      if (pipeline_info_.op_infos[i].group_size == 0)
        continue;
428
      Array<Stmt> block_stmt;
429
430
      for (int j = 0;
           j < static_cast<int>(pipeline_info_.op_infos[i].group_size); j++) {
431
        // ICHECK(group_info_[i][j].as<IntImmNode>());
432
433
        // int index =
        // static_cast<int>(group_info_[i][j].as<IntImmNode>()->value);
434
435
436
437
438
439
        ICHECK(original_node->seq[cur_id].as<BlockNode>());
        auto block = original_node->seq[cur_id].as<BlockNode>();
        // TODO: handle nested seqstmt
        block_stmt.push_back(block->body);
        cur_id++;
      }
440
441
442
443
      new_body.push_back(MakeGroupBlock(block_stmt.size() == 1
                                            ? block_stmt[0]
                                            : SeqStmt(std::move(block_stmt)),
                                        annotations));
444
445
446
447
448
449
450
    }
    Array<Integer> order_anno;
    Array<Integer> stage_anno;
    for (auto op_info : pipeline_info_.op_infos) {
      order_anno.push_back(Integer(op_info.order));
      stage_anno.push_back(Integer(op_info.stage));
    }
451
    Map<String, Any> for_annotations = op->annotations;
452
453
454
    for_annotations.erase("tl_pipeline_group");
    for_annotations.Set("software_pipeline_order", order_anno);
    for_annotations.Set("software_pipeline_stage", stage_anno);
455
456
457
458
    For new_for =
        For(op->loop_var, op->min, op->extent, op->kind,
            new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body)),
            op->thread_binding, for_annotations);
459
460
461
462
463
464
    return new_for;
  }

  PipelineInfo pipeline_info_;
};
class WSCodeEmitter : public StmtMutator {
465
public:
466
  WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
467
                Map<Var, Buffer> buffer_data_to_buffer,
468
469
                const WarpSpecializedRoleMarker &marker,
                bool mbarrier_only = false)
470
      : is_emitting_producer_(is_emitting_producer),
471
        buffer_data_to_buffer_(buffer_data_to_buffer), marker_(marker),
472
        thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only) {}
473

474
475
private:
  /*!
   * \brief Keep, drop or recurse into a statement depending on its role.
   *
   * - Statements marked kBoth are recursed into; in mbarrier-only mode every
   *   non-producer statement is recursed into as well.
   * - A statement whose role matches the side being emitted is returned
   *   unchanged.
   * - Anything else is replaced by `Evaluate(0)`.
   */
  template <typename NodeType> Stmt FilterByRole(const NodeType *op) {
    Role role = marker_.GetRole(op);
    if (mbarrier_only_ && role != Role::kProducer) {
      return StmtMutator::VisitStmt_(op);
    }
    if (role == Role::kBoth) {
      return StmtMutator::VisitStmt_(op);
    } else if ((role == Role::kProducer) == is_emitting_producer_) {
      return GetRef<Stmt>(op);
    } else {
      return Evaluate(0);
    }
  }

  // TODO: only need to add block for ops in the loop
491
  Stmt VisitStmt_(const SeqStmtNode *op) final {
492

493
494
495
496
497
498
499
    bool has_producer = false;
    for (auto stmt : op->seq) {
      if (marker_.GetRole(stmt) == Role::kProducer) {
        has_producer = true;
        break;
      }
    }
500
501
502
503
    bool need_producer_sync =
        has_producer && marker_.GetRole(op) == Role::kBoth;
    if (!need_producer_sync)
      return FilterByRole(op);
504

505
506
    auto seq_transformed =
        op->seq.Map([&](Stmt stmt) { return VisitStmt(stmt); });
507
508

    auto map = ExtractSyncPattern(op->seq);
509
510
511
512
513
514
515
516
517
518
519
520
521
522
    /*
      std::cout << "Print ExtractSyncPattern" << std::endl;
      for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
        std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " "
        << map.release_after[i] << std::endl;
      }
      std::cout << "Print sync pattern" << std::endl;
      for (auto pattern : map.patterns) {
        std::cout << pattern.release_idx << " " << pattern.acquire_idx <<
        std::endl;
      }
      std::cout << "End of ExtractSyncPattern" << std::endl;
      pipeline_info_.PrintPipelineInfo();
    */
523
524
525
526
    Array<Stmt> new_body;
    Map<String, ObjectRef> annotations;
    annotations.Set(String("stmt_group"), Integer(1));

527
    if (is_emitting_producer_) { // producer case
528
529
530
      ProducerTraitsCollector collector;
      for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
        Array<Stmt> block_stmt = {};
531
532
533
534
535
536
537
538
539
540
541
        if (!mbarrier_only_) {
          if (marker_.GetRole(op->seq[i]) == Role::kConsumer)
            continue;
          if (marker_.GetRole(op->seq[i]) == Role::kBoth) {
            block_stmt.push_back(seq_transformed[i]);
            new_body.push_back(MakeGroupBlock(
                block_stmt.size() == 1 ? block_stmt[0]
                                       : SeqStmt(std::move(block_stmt)),
                annotations));
            continue;
          }
542
        }
543

544
        for (int pattern_idx : map.acquire[i]) {
545
          PrimExpr acquire_barrier_id =
546
547
              stage_ + num_barriers_ + num_stages_ * pattern_idx;
          PrimExpr parity = map.is_loop_dependency(pattern_idx)
548
549
                                ? bitwise_xor(parity_, 1)
                                : parity_;
550
551
          block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
        }
552
553
554
555
556
557
558
559
        ICHECK(map.release[i].size() > 0);
        for (size_t j = 0; j < map.release[i].size(); j++) {
          int pattern_idx = map.release[i][j];
          PrimExpr release_barrier_id =
              stage_ + num_barriers_ + num_stages_ * pattern_idx;
          auto stmt =
              MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id);
          collector.Collect(stmt);
560
          block_stmt.push_back(stmt);
561
562
          if (collector.HasSimtCopy() > 0) {
            block_stmt.push_back(makeCpAsyncBarrier(release_barrier_id));
563
          }
564
565
566
567
568
569
570
571
572
573
574
575
          if (map.release_after[i][j]) {
            block_stmt.push_back(makeArriveBarrier(release_barrier_id));
            for (int s = 0; s < num_stages_; s++) {
              released_barrier_.insert(s + num_barriers_ +
                                       num_stages_ * pattern_idx);
            }
          }
          collector.Clear();
          new_body.push_back(MakeGroupBlock(
              block_stmt.size() == 1 ? block_stmt[0]
                                     : SeqStmt(std::move(block_stmt)),
              annotations));
576
577
        }
      }
578
    } else { // consumer case
579
580
      for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
        Array<Stmt> block_stmt = {};
581
582
        if (marker_.GetRole(op->seq[i]) == Role::kProducer)
          continue;
583
        for (int pattern_idx : map.acquire[i]) {
584
          PrimExpr acquire_barrier_id =
585
586
              stage_ + num_barriers_ + num_stages_ * pattern_idx;
          PrimExpr parity = map.is_loop_dependency(pattern_idx)
587
588
                                ? bitwise_xor(parity_, 1)
                                : parity_;
589
590
591
          block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
        }
        block_stmt.push_back(seq_transformed[i]);
592
593
594
595
596
597
598
599
600
601
        for (size_t j = 0; j < map.release[i].size(); j++) {
          if (map.release_after[i][j]) {
            int pattern_idx = map.release[i][j];
            PrimExpr release_barrier_id =
                stage_ + num_barriers_ + num_stages_ * pattern_idx;
            block_stmt.push_back(makeArriveBarrier(release_barrier_id));
            for (int s = 0; s < num_stages_; s++) {
              released_barrier_.insert(s + num_barriers_ +
                                       num_stages_ * pattern_idx);
            }
602
603
          }
        }
604
605
606
607
        new_body.push_back(MakeGroupBlock(block_stmt.size() == 1
                                              ? block_stmt[0]
                                              : SeqStmt(std::move(block_stmt)),
                                          annotations));
608
609
610
611
      }
      // Filter out the producer stmts
      int cur_id = 0;
      PipelineInfo new_pipeline_info;
612
613
      for (int i = 0; i < static_cast<int>(pipeline_info_.op_infos.size());
           i++) {
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
        auto op_info = pipeline_info_.op_infos[i];
        bool is_producer = false;
        for (int j = 0; j < op_info.group_size; j++) {
          if (marker_.GetRole(op->seq[cur_id]) == Role::kProducer) {
            is_producer = true;
          }
          cur_id++;
        }
        if (is_producer) {
          ICHECK(op_info.group_size == 1);
        } else {
          new_pipeline_info.op_infos.push_back(op_info);
        }
      }
      pipeline_info_ = new_pipeline_info;
    }

    num_barriers_ += map.patterns.size() * num_stages_;

    ICHECK(new_body.size() > 0);
    return new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body));
  }

637
  /*!
   * \brief Emit one side of a loop, tracking pipeline stage and parity.
   *
   * While the loop body is processed, `stage_`, `parity_`, `num_stages_`
   * and `pipeline_info_` describe this loop; they are restored afterwards.
   * On the consumer side a loop carrying a non-empty "tl_pipeline_group"
   * annotation is additionally regrouped by GroupOpRewriter. Nested
   * pipelines are rejected with an ICHECK.
   */
  Stmt VisitStmt_(const ForNode *op) final {
    int num_stages = 1;
    auto num_stages_anno = op->annotations.Get("num_stages");
    if (num_stages_anno) {
      ICHECK(num_stages_anno->as<IntImmNode>());
      num_stages = static_cast<int>(num_stages_anno->as<IntImmNode>()->value);
      ICHECK(num_stages_ == 1) << "Nested pipeline not supported.";
    }
    loop_stack_.emplace_back(op->loop_var, op->extent);

    Array<Array<Integer>> group_info_array;
    Array<Integer> order_info_array;
    Array<Integer> stage_info_array;

    auto group_anno = op->annotations.Get("tl_pipeline_group");
    if (group_anno) {
      group_info_array = Downcast<Array<Array<Integer>>>(group_anno.value());
    }
    auto order_anno = op->annotations.Get("tl_pipeline_order");
    if (order_anno) {
      order_info_array = Downcast<Array<Integer>>(order_anno.value());
    }
    auto stage_anno = op->annotations.Get("tl_pipeline_stage");
    if (stage_anno) {
      stage_info_array = Downcast<Array<Integer>>(stage_anno.value());
    }

    PipelineInfo pipeline_info(group_info_array, order_info_array,
                               stage_info_array);
    if (pipeline_info.op_infos.size() > 0) {
      ICHECK(pipeline_info_.op_infos.size() == 0)
          << "Nested pipeline not supported.";
    }

    // Save the enclosing state so it can be restored after the body.
    PrimExpr parity_before = std::move(parity_);
    PrimExpr stage_before = std::move(stage_);
    int num_stages_before = num_stages_;
    PipelineInfo pipeline_info_before = pipeline_info_;

    num_stages_ = num_stages;
    pipeline_info_ = pipeline_info;

    // Flatten the enclosing loop nest into a single iteration index.
    PrimExpr linear_index = loop_stack_[0].first;
    for (size_t i = 1; i < loop_stack_.size(); ++i) {
      linear_index =
          linear_index * loop_stack_[i].second + loop_stack_[i].first;
    }
    stage_ = FloorMod(linear_index, num_stages);
    parity_ = FloorMod(
        parity_before * op->extent + FloorDiv(linear_index, num_stages), 2);

    auto result = FilterByRole(op);
    const bool result_is_for = result.as<ForNode>() != nullptr;
    const bool has_group_info = group_anno && group_info_array.size() > 0;

    // Regrouping uses the pipeline info as left by the body visit, so it has
    // to happen before the state is restored.
    Stmt grouped_for_node;
    if (result_is_for && has_group_info && !is_emitting_producer_) {
      GroupOpRewriter group_op_rewriter(pipeline_info_);
      grouped_for_node = group_op_rewriter(Downcast<For>(result));
    }

    parity_ = std::move(parity_before);
    stage_ = std::move(stage_before);
    num_stages_ = num_stages_before;
    pipeline_info_ = pipeline_info_before;
    loop_stack_.pop_back();

    if (!result_is_for) {
      return result;
    }

    // remove pipeline annotation
    auto for_node = Downcast<For>(result);
    for_node.CopyOnWrite()->annotations.erase("num_stages");
    if (is_emitting_producer_ || group_info_array.size() == 0) {
      for_node.CopyOnWrite()->annotations.erase("tl_pipeline_order");
      for_node.CopyOnWrite()->annotations.erase("tl_pipeline_stage");
    }
    if (is_emitting_producer_ || !has_group_info) {
      return for_node;
    }
    return grouped_for_node;
  }

723
724
725
726
727
728
729
  // The statement kinds below carry no special handling: each one is kept,
  // dropped or recursed into by FilterByRole according to its role.
  Stmt VisitStmt_(const IfThenElseNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const EvaluateNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const AttrStmtNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const BufferStoreNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const LetStmtNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const AssertStmtNode *op) final { return FilterByRole(op); }
  /*! \brief Blocks are not supported by this emitter; reaching one is fatal. */
  Stmt VisitStmt_(const BlockNode *op) final {
    ICHECK(0) << "WSCodeEmitter does not support visiting a BlockNode.";
    return Stmt();
  }
733
  /*!
   * \brief Block realizations are not supported by this emitter; reaching
   *        one is fatal.
   */
  Stmt VisitStmt_(const BlockRealizeNode *op) final {
    ICHECK(0) << "WSCodeEmitter does not support visiting a BlockRealizeNode.";
    return Stmt();
  }

  /*!
   * \brief One release/acquire pairing between two statements of a sequence,
   *        both given as positions in that sequence.
   */
  struct SyncPattern {
    int release_idx;
    int acquire_idx;
  };

  struct SyncPatternMap {
743
744
745
    std::vector<std::vector<int>> acquire;
    std::vector<std::vector<int>> release;
    std::vector<std::vector<bool>> release_after;
746
    std::vector<SyncPattern> patterns;
747
748
749
750
751
752
753
754
755
756

    void resize(size_t n) {
      acquire.resize(n);
      release.resize(n);
      release_after.resize(n);
    }

    bool is_loop_dependency(int pattern_idx) {
      return patterns[pattern_idx].release_idx >
             patterns[pattern_idx].acquire_idx;
757
758
759
    }
  };

760
761
762
  /*!
   * \brief Derive the raw release/acquire pairs of a statement sequence from
   *        buffer accesses.
   *
   * Two statements are paired when exactly one of them is a producer and one
   * reads a buffer that the other writes. For every statement only the first
   * matching partner in each direction is recorded.
   *
   * \param seq_stmt The statements of the sequence.
   * \param is_producer For each statement, whether it is a producer.
   * \return The (release_idx, acquire_idx) pairs that were found.
   */
  std::vector<SyncPattern>
  CreateBaseSyncPairs(Array<Stmt> seq_stmt,
                      const std::vector<bool> &is_producer) {
    const int n = seq_stmt.size();
    std::vector<std::set<const BufferNode *>> reads, writes;
    reads.reserve(n);
    writes.reserve(n);

    // Map each accessed region to its buffer, preferring the buffer that is
    // registered for the region's data variable.
    auto collect_buffers = [this](const auto &regions) {
      std::set<const BufferNode *> buffers;
      for (auto region : regions) {
        auto var = region->buffer->data;
        if (buffer_data_to_buffer_.count(var)) {
          buffers.insert(buffer_data_to_buffer_[var].get());
        } else {
          buffers.insert(region->buffer.get());
        }
      }
      return buffers;
    };

    for (int i = 0; i < n; i++) {
      Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
                  /*name_hint=*/"",
                  /*body*/ seq_stmt[i]);
      auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
      reads.push_back(collect_buffers(access[0]));
      writes.push_back(collect_buffers(access[1]));
    }

    auto intersect_fn = [](const std::set<const BufferNode *> &lhs,
                           const std::set<const BufferNode *> &rhs) {
      for (auto ptr : lhs)
        if (rhs.count(ptr))
          return true;
      return false;
    };

    std::vector<SyncPattern> sync_patterns;
    // producer_release consumer_acquire,
    // inject before the first consumer stmt for each producer
    for (int i = 0; i < n; i++) {
      for (int j = i + 1; j < n; j++) {
        if (is_producer[i] != is_producer[j] &&
            (intersect_fn(writes[i], reads[j]) ||
             intersect_fn(reads[i], writes[j]))) {
          sync_patterns.push_back({i, j});
          break;
        }
      }
    }

    // consumer_release producer_acquire
    // valid when is_loop is true
    // inject before the earliest producer stmt for each consumer
    bool in_loop = !is_zero(parity_);
    if (in_loop) {
      for (int i = 0; i < n; i++) {
        for (int j = 0; j < i; j++) {
          if (is_producer[i] != is_producer[j] &&
              (intersect_fn(writes[i], reads[j]) ||
               intersect_fn(reads[i], writes[j]))) {
            sync_patterns.push_back({i, j});
            break;
          }
        }
      }
    }

    return sync_patterns;
  }

835
836
837
  static std::vector<SyncPattern>
  RemoveUnusedSyncPatterns(const std::vector<SyncPattern> &sync_patterns,
                           const std::vector<bool> &is_producer) {
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
    /*
      Simplify multiple release-acquire pairs into one
      ------------------
        Produce(A)
        Produce(B)
        Consume(A, B)
      ------------------
      [(0, 2), (1, 2), (2, 0)] -> [(1, 2), (2, 0)]

      Or
      ------------------
        Produce(A, B)
        Consume(A)
        Consume(B)
      ------------------
      [(0, 1), (1, 0), (2, 0)] -> [(0, 1), (2, 0)]
    */
    int M = sync_patterns.size();
    std::vector<bool> removed(M, false);
    for (int i = 0; i < M; i++) {
      for (int j = 0; j < M; j++) {
        if (is_producer[sync_patterns[i].acquire_idx] ==
                is_producer[sync_patterns[j].acquire_idx] &&
            sync_patterns[i].acquire_idx >= sync_patterns[j].acquire_idx &&
            sync_patterns[i].release_idx < sync_patterns[j].release_idx)
          removed[i] = true;
      }
    }

    std::vector<SyncPattern> sync_pattern_cleaned;
    sync_pattern_cleaned.reserve(M);
    for (int i = 0; i < M; i++)
870
871
      if (!removed[i])
        sync_pattern_cleaned.push_back(sync_patterns[i]);
872
873
874
875
876
877
878
879
880
881
882
883
884

    return sync_pattern_cleaned;
  }

  /*!
   * \brief Compute the synchronization patterns of a statement sequence and
   *        index them per statement.
   *
   * After the pairs are derived and simplified, each pattern is registered
   * on its acquiring and on its releasing statement (with `release_after`
   * set to true). A second pass, walking the sequence backwards, gives every
   * statement without a release of its own the releases of the later
   * statements on the same side, with `release_after` set to false.
   *
   * \param seq_stmt The statements of the sequence.
   * \return The per-statement acquire/release tables and the pattern list.
   */
  SyncPatternMap ExtractSyncPattern(Array<Stmt> seq_stmt) {
    size_t num_stmts = seq_stmt.size();
    std::vector<bool> is_producer;
    is_producer.reserve(num_stmts);
    for (auto stmt : seq_stmt) {
      is_producer.push_back(marker_.GetRole(stmt) == Role::kProducer);
    }

    auto sync_patterns_base = CreateBaseSyncPairs(seq_stmt, is_producer);
    auto sync_patterns =
        RemoveUnusedSyncPatterns(sync_patterns_base, is_producer);

    // for (auto pattern : sync_patterns) {
    //   std::cout << pattern.release_idx << " " << pattern.acquire_idx <<
    //   std::endl;
    // }

    SyncPatternMap map;
    map.resize(num_stmts);
    map.patterns = sync_patterns;

    for (size_t i = 0; i < sync_patterns.size(); i++) {
      int acquire_idx = sync_patterns[i].acquire_idx;
      int release_idx = sync_patterns[i].release_idx;

      map.acquire[acquire_idx].push_back(i);
      map.release[release_idx].push_back(i);
      map.release_after[release_idx].push_back(true);
    }

    // Releases seen so far on each side while walking backwards.
    std::vector<int> cur_consumer_barrier, cur_producer_barrier;
    for (int i = static_cast<int>(num_stmts) - 1; i >= 0; i--) {
      std::vector<int> &cur_barrier =
          is_producer[i] ? cur_producer_barrier : cur_consumer_barrier;
      if (map.release[i].size() == 0) {
        // No release of its own: inherit the ones collected so far.
        for (auto pattern_idx : cur_barrier) {
          map.release[i].push_back(pattern_idx);
          map.release_after[i].push_back(false);
        }
      } else {
        for (auto pattern_idx : map.release[i]) {
          cur_barrier.push_back(pattern_idx);
        }
      }
    }
    return map;
  }

  // True when this emitter produces the producer side, false for the
  // consumer side.
  const bool is_emitting_producer_;
  // Map from buffer data variable to buffer, used for access-region analysis.
  Map<Var, Buffer> buffer_data_to_buffer_;
  // Barrier ids for which a barrier arrive has been emitted.
  std::unordered_set<int> released_barrier_;
938
  // Role lookup; held by reference, so the marker must outlive the emitter.
  const WarpSpecializedRoleMarker &marker_;
939
940
941
942
943

  // Number of barrier ids consumed by the sequences processed so far.
  int num_barriers_ = 0;
  // Parity and stage expressions of the loop currently being visited.
  PrimExpr parity_ = 0;
  PrimExpr stage_ = 0;
  // Number of pipeline stages of the loop currently being visited.
  int num_stages_ = 1;
944
  // (loop variable, extent) of every enclosing loop, outermost first.
  std::vector<std::pair<Var, PrimExpr>> loop_stack_;
945
  // Variable of the thread iteration variable passed to the constructor.
  Var thread_var_;
946
  // Constructor flag that changes how statements are filtered by role.
  bool mbarrier_only_ = false;
947
948
949
950
  // Pipeline op groups of the loop currently being visited.
  PipelineInfo pipeline_info_;
  friend class WarpSpecializedRewriter;
};

951
952
953
954
955
class SetMaxNRegCollector : public StmtExprVisitor {
public:
  static Array<IntImm> Collect(const PrimFunc &f) {
    SetMaxNRegCollector collector;
    collector(f->body);
956
957
958
959
    return collector.has_no_set_max_nreg_
               ? Array<IntImm>({IntImm(DataType::Int(32), -1),
                                IntImm(DataType::Int(32), -1)})
               : collector.nreg_;
960
961
962
963
964
  }

private:
  void VisitStmt_(const EvaluateNode *op) final {
    if (const CallNode *call = op->value.as<CallNode>()) {
965
      if (call->op.same_as(set_max_nreg())) {
966
967
968
969
970
971
972
973
974
        int reg_hint = call->args[0].as<IntImmNode>()->value;
        int is_inc = call->args[1].as<IntImmNode>()->value;
        ICHECK(reg_hint <= 240 && reg_hint >= 24)
            << "Invalid reg hint: " << reg_hint;
        ICHECK(is_inc == 0 || is_inc == 1) << "Invalid is_inc: " << is_inc;

        // producer should decrease register hint while consumer should increase
        // register hint
        nreg_.Set(is_inc, IntImm(DataType::Int(32), reg_hint));
975
      } else if (call->op.same_as(no_set_max_nreg())) {
976
        has_no_set_max_nreg_ = true;
977
978
979
980
981
982
983
      }
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  Array<IntImm> nreg_{IntImm(DataType::Int(32), 0),
                      IntImm(DataType::Int(32), 0)};
984
  bool has_no_set_max_nreg_ = false;
985
986
};

987
class WarpSpecializedRewriter : public StmtExprMutator {
988
public:
989
990
991
  WarpSpecializedRewriter(bool disable_warp_specialized)
      : disable_warp_specialized_(disable_warp_specialized) {}
  static PrimFunc Substitute(PrimFunc f, bool disable_warp_specialized) {
992
993
994
    // Check if function only uses threadIdx.x before proceeding
    if (!ThreadTagChecker::HasOnlyThreadIdxX(f)) {
      LOG(WARNING) << "WarpSpecialize will be disabled because the program "
995
                      "uses thread tags other than threadIdx.x."
996
997
998
999
1000
1001
                   << "If you want to use warp specialization, please refactor "
                      "your program to use threadIdx.x only";
      // Return original function unchanged if other thread tags are found
      return f;
    }

1002
    auto T = WarpSpecializedRewriter(disable_warp_specialized);
1003
    T.nreg_ = SetMaxNRegCollector::Collect(f);
1004
    T.buffer_lca_ = DetectBufferAccessLCA(f);
1005
1006
    for (auto [buffer, _] : T.buffer_lca_)
      T.buffer_data_to_buffer_.Set(buffer->data, buffer);
1007
1008
1009
1010
    f.CopyOnWrite()->body = T(f->body);
    return f;
  }

1011
1012
private:
  Stmt VisitStmt_(const AttrStmtNode *op) final {
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
    if (op->attr_key == tir::attr::thread_extent &&
        Downcast<IterVar>(op->node)->thread_tag == "threadIdx.x") {
      thread_iv_ = Downcast<IterVar>(op->node);
      need_update_thread_extent_ = false;
      AttrStmt attr_stmt = Downcast<AttrStmt>(StmtExprMutator::VisitStmt_(op));
      if (need_update_thread_extent_) {
        thread_iv_.CopyOnWrite()->dom = {0, updated_thread_extent_.value()};
        attr_stmt.CopyOnWrite()->node = thread_iv_;
        attr_stmt.CopyOnWrite()->value = updated_thread_extent_.value();
      }
      thread_iv_ = {};
      return attr_stmt;
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

1030
1031
  Stmt VisitStmt_(const EvaluateNode *op) final {
    if (const CallNode *call = op->value.as<CallNode>()) {
1032
1033
      if (call->op.same_as(set_max_nreg()) ||
          call->op.same_as(no_set_max_nreg())) {
1034
1035
1036
1037
1038
1039
        return Evaluate(0);
      }
    }
    return StmtExprMutator::VisitStmt_(op);
  }

1040
1041
1042
1043
  // If users define a thread binding, we will replace the thread binding with
  // threadIdx.x We require the thread binding is threadIdx.x, and the extent is
  // the same as the thread extent
  Stmt VisitStmt_(const ForNode *op) final {
1044
1045
1046
1047
1048
1049
1050
    ICHECK(thread_iv_.defined());
    For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
    if (for_node->kind == ForKind::kThreadBinding) {
      ICHECK(for_node->thread_binding.defined());
      String thread_tag = for_node->thread_binding.value()->thread_tag;
      ICHECK(thread_tag == "threadIdx.x") << "Only support threadIdx.x";
      Var thread_iv = Downcast<Var>(for_node->loop_var);
1051
1052
      Stmt new_body =
          ThreadIdxRewriter::Rewrite(for_node->body, thread_iv, thread_iv_);
1053
1054
1055
1056
1057
      return new_body;
    }
    return for_node;
  }

1058
1059
1060
  Stmt VisitStmt_(const BlockRealizeNode *op) final {
    BlockRealize block_realize =
        Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
1061
1062
1063
1064
1065
1066
    if (!thread_iv_.defined()) {
      return block_realize;
    }

    Block block = block_realize->block;
    WarpSpecializedRoleMarker marker(buffer_data_to_buffer_);
1067
    marker.Prepare(block);
1068
1069
1070
1071
1072
1073
    marker(block);
    if (!marker.HasProducer()) {
      // Cannot detect any producer here, directly return.
      return block_realize;
    }

1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
    if (disable_warp_specialized_) {
      WSCodeEmitter mbarrier_emitter(true, thread_iv_, buffer_data_to_buffer_,
                                     marker, true);
      auto code = mbarrier_emitter(block->body);
      int num_barriers = mbarrier_emitter.num_barriers_;
      Array<PrimExpr> barrier_num_threads;
      barrier_num_threads.reserve(num_barriers);
      PrimExpr arrive_thread_count = thread_iv_->dom->extent;
      for (int i = 0; i < num_barriers; i++) {
        barrier_num_threads.push_back(arrive_thread_count);
      }
      Stmt init_barrier = Evaluate(Call(
1086
          DataType::Handle(), create_list_of_mbarrier(), barrier_num_threads));
1087
1088
1089
1090
      block.CopyOnWrite()->body = SeqStmt({init_barrier, code});
      block_realize.CopyOnWrite()->block = block;
      return block_realize;
    }
1091
1092
1093
1094
1095
1096
1097
    WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker);
    WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker);
    Stmt producer_code = producer(block->body);
    Stmt consumer_code = consumer(block->body);
    PrimExpr consumer_thread_extent = thread_iv_->dom->extent;
    PrimExpr producer_thread_extent = thread_iv_->dom->extent;
    // Need one warp-group for bulk-copy only case
1098
1099
    if (!marker.HasSimtCopy())
      producer_thread_extent = 128;
1100
1101

    // TODO: estimate the correct reg usage.
1102
1103
1104
    int dec_reg = nreg_[0].as<IntImmNode>()->value;
    int inc_reg = nreg_[1].as<IntImmNode>()->value;

1105
1106
1107
    auto inc_reg_stmt = Evaluate(0);
    auto dec_reg_stmt = Evaluate(0);
    if (dec_reg >= 0 && inc_reg >= 0) {
1108
      inc_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
1109
                                   {inc_reg == 0 ? 240 : inc_reg, 1}));
1110
      dec_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
1111
1112
                                   {dec_reg == 0 ? 24 : dec_reg, 0}));
    }
1113
1114
1115
1116

    producer_code = SeqStmt({dec_reg_stmt, producer_code});
    consumer_code = SeqStmt({inc_reg_stmt, consumer_code});

1117
1118
1119
    producer_code =
        ThreadIdxRewriter::Rewrite(producer_code, thread_iv_->var,
                                   thread_iv_->var - consumer_thread_extent);
1120
1121
1122
1123
1124
1125
1126
1127
1128
    updated_thread_extent_ = consumer_thread_extent + producer_thread_extent;
    need_update_thread_extent_ = true;

    ICHECK(producer.num_barriers_ == consumer.num_barriers_)
        << producer.num_barriers_ << " " << consumer.num_barriers_;
    int num_barriers = consumer.num_barriers_;
    Array<PrimExpr> barrier_num_threads;
    barrier_num_threads.reserve(num_barriers);
    for (int i = 0; i < num_barriers; i++) {
1129
1130
1131
      PrimExpr arrive_thread_count = producer.released_barrier_.count(i)
                                         ? producer_thread_extent
                                         : consumer_thread_extent;
1132
1133
1134
      barrier_num_threads.push_back(arrive_thread_count);
    }

1135
    Stmt init_barrier = Evaluate(Call(
1136
        DataType::Handle(), create_list_of_mbarrier(), barrier_num_threads));
1137
1138
    Stmt body = IfThenElse(GE(thread_iv_->var, consumer_thread_extent),
                           producer_code, consumer_code);
1139
    // Add an attr here to handle the partial thread count in ThreadSync pass.
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
    Array<IntImm> ws_partition = {Downcast<IntImm>(producer_thread_extent),
                                  Downcast<IntImm>(consumer_thread_extent)};
    body = AttrStmt(ws_partition, "kWarpSpecializationScope", 0, body);

    block.CopyOnWrite()->body = SeqStmt({init_barrier, body});
    block_realize.CopyOnWrite()->block = block;
    return block_realize;
  }

  WarpSpecializedRewriter() = default;

  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Optional<Stmt>> buffer_lca_;
  Map<Buffer, Buffer> buffer_remap_;
  IterVar thread_iv_;
  Optional<PrimExpr> updated_thread_extent_;
  bool need_update_thread_extent_ = false;
1157
  bool disable_warp_specialized_ = false;
1158
  Array<IntImm> nreg_;
1159
1160
};

1161
1162
1163
1164
1165
class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
public:
  static bool Detect(Stmt stmt, bool skip_thread_partition = false) {
    WarpSpecializedDetector detector;
    detector.VisitStmt(stmt);
1166
1167
    return detector.has_warp_specialization_ ||
           (detector.has_tma_op_ && detector.has_mbarrier_op_);
1168
1169
1170
1171
1172
  }

  WarpSpecializedDetector() {
    has_tma_op_ = false;
    has_mbarrier_op_ = false;
1173
    has_warp_specialization_ = false;
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
  }

private:
  void VisitStmt_(const EvaluateNode *op) final {
    if (const CallNode *call = op->value.as<CallNode>()) {
      if (call->op.same_as(create_list_of_mbarrier()) ||
          call->op.same_as(mbarrier_wait_parity()) ||
          call->op.same_as(builtin::ptx_arrive_barrier()) ||
          call->op.same_as(builtin::ptx_cp_async_barrier())) {
        has_mbarrier_op_ = true;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  void VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
        op->op.same_as(set_max_nreg())) {
      has_tma_op_ = true;
    }
    IRVisitorWithAnalyzer::VisitExpr_(op);
  }

1197
  void VisitStmt_(const AttrStmtNode *op) final {
1198
1199
1200
1201
    if (op->attr_key == "warp_specialize" &&
        op->value.as<IntImmNode>()->value == 1) {
      has_warp_specialization_ = true;
    }
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        ICHECK(iv->dom->extent.as<IntImmNode>());
        thread_var_ = iv;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

1212
  bool has_tma_op_{false};
1213
  IterVar thread_var_;
1214
  bool has_mbarrier_op_{false};
1215
  bool has_warp_specialization_{false};
1216
1217
};

1218
1219
1220
1221
using namespace tir::transform;

/*!
 * \brief Create the "tl.WarpSpecialized" PrimFunc pass.
 *
 * The pass reads the kDisableWarpSpecialized option from the pass context
 * (default false). A function that WarpSpecializedDetector already reports as
 * warp specialized is returned untouched; every other function is handed to
 * WarpSpecializedRewriter::Substitute.
 */
tvm::transform::Pass WarpSpecialized() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    const bool disable_warp_specialized =
        ctx->GetConfig<Bool>(kDisableWarpSpecialized, Bool(false)).value();

    // Nothing to do when the body is already warp specialized.
    if (WarpSpecializedDetector::Detect(f->body)) {
      return f;
    }
    return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
}

1234
1235
1236
1237
// Registers the WarpSpecialized pass factory as the global function
// "tl.transform.WarpSpecialized" through the FFI reflection registry.
// NOTE(review): presumably executed once during static initialization, as the
// macro name suggests; confirm against the TVM FFI headers.
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.WarpSpecialized", WarpSpecialized);
});
1238

1239
1240
} // namespace tl
} // namespace tvm