warp_specialized_rewriter.cc 43.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file warp_specialized_rewriter.cc
 * \brief Warp-specialized pipeline rewriter for CUDA GPUs (sm90+)
 */

25
#include "tir/analysis/var_use_def_analysis.h"
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../op/builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

// Role of a statement under warp specialization, as assigned by
// WarpSpecializedRoleMarker: kProducer for TMA loads and for shared-memory
// stores whose reads are all from global buffers, kConsumer for other leaf
// statements, and kBoth for compound statements whose children disagree and
// for stores to buffers used in producer conditions.
enum class Role { kConsumer, kProducer, kBoth };

41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
/*!
 * \brief Visitor that records whether anything it visited contains a TMA load
 *        call (TMALoadOp or TMALoadIm2ColOp).
 */
class TMAFinder : public StmtExprVisitor {
public:
  /*! \brief Forget any TMA load seen so far. */
  void clear() { has_tma_load_ = false; }

  void VisitExpr_(const CallNode *call) final {
    // The base-class visitor is not invoked here, so the arguments of the
    // call are not traversed.
    const bool is_tma_load =
        call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp());
    if (is_tma_load) {
      has_tma_load_ = true;
    }
  }

  /*! \brief Set once a TMA load call has been visited. */
  bool has_tma_load_ = false;
};

/*!
 * \brief Collects the buffers referenced by the conditions of `if` statements
 *        and by the bounds (min/extent) of `for` loops whose bodies contain a
 *        TMA load.
 */
class ProducerUsedBufferFinder : public StmtExprVisitor {
public:
  /*!
   * \brief Visit \p stmt and return the set of buffers collected so far.
   * \param stmt The statement to scan.
   * \return A copy of the accumulated buffer set.
   */
  auto FindProducerusedBuffer(Stmt stmt) {
    VisitStmt(stmt);
    return used_in_producer_cond_;
  }

  /*!
   * \brief Record every buffer that VarUseDefAnalyzer reports as used by
   *        \p expr.
   */
  void InsertBuffer(const PrimExpr &expr) {
    VarUseDefAnalyzer usage(Array<Var>{});
    usage(expr);
    for (const auto &buffer : usage.buffer_use_count_) {
      used_in_producer_cond_.insert(buffer.first);
    }
  }

  void VisitStmt_(const IfThenElseNode *op) final {
    // Only conditions guarding a TMA load (in either branch) are of interest.
    TMAFinder tma_finder;
    tma_finder(op->then_case);
    if (op->else_case.defined()) {
      tma_finder(op->else_case.value());
    }
    if (tma_finder.has_tma_load_) {
      InsertBuffer(op->condition);
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const ForNode *op) final {
    // Loop bounds matter only when the loop body issues a TMA load.
    TMAFinder tma_finder;
    tma_finder(op->body);
    if (tma_finder.has_tma_load_) {
      InsertBuffer(op->min);
      InsertBuffer(op->extent);
    }
    StmtExprVisitor::VisitStmt_(op);
  }

private:
  std::unordered_set<const BufferNode *> used_in_producer_cond_;
};

98
class WarpSpecializedRoleMarker : public StmtVisitor {
99
public:
100
101
102
  WarpSpecializedRoleMarker(Map<Var, Buffer> buffer_data_to_buffer)
      : buffer_data_to_buffer_(buffer_data_to_buffer) {}

103
104
105
106
107
  void Prepare(const Stmt &stmt) {
    ProducerUsedBufferFinder finder;
    used_in_producer_cond_ = finder.FindProducerusedBuffer(stmt);
  }

108
  Role GetRole(const StmtNode *stmt) const {
109
110
111
112
113
    auto it = map_.find(stmt);
    ICHECK(it != map_.end());
    return it->second;
  }

114
  Role GetRole(const Stmt &stmt) const { return GetRole(stmt.get()); }
115

116
  void VisitStmt_(const EvaluateNode *op) final {
117
118
    Role role = Role::kConsumer;
    if (auto call = op->value.as<CallNode>()) {
119
120
      if (call->op.same_as(TMALoadOp()) ||
          call->op.same_as(TMALoadIm2ColOp())) {
121
122
123
124
125
126
127
        role = Role::kProducer;
        has_bulk_copy_ = true;
      }
    }
    SetRole(op, role);
  }

128
129
130
  void VisitStmt_(const BufferStoreNode *op) final {
    bool is_shared_store =
        op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
131
132
133
134
    if (used_in_producer_cond_.count(op->buffer.get())) {
      SetRole(op, Role::kBoth);
      return;
    }
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
    if (!is_shared_store) {
      SetRole(op, Role::kConsumer);
      return;
    }

    // Check reads from global
    Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
                /*body*/ GetRef<Stmt>(op));
    auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
    auto reads = access[0];
    Role role = Role::kProducer;
    for (auto read : reads) {
      if (read->buffer.scope() != "global") {
        role = Role::kConsumer;
        break;
      }
    }
152
153
    if (role == Role::kProducer)
      has_simt_copy_ = true;
154
155
156
    SetRole(op, role);
  }

157
  void VisitStmt_(const SeqStmtNode *op) final {
158
159
160
161
162
163
164
165
166
167
168
    StmtVisitor::VisitStmt_(op);
    auto role = GetRole(op->seq[0]);
    for (auto stmt : op->seq) {
      if (role != GetRole(stmt)) {
        role = Role::kBoth;
        break;
      }
    }
    SetRole(op, role);
  }

169
  void VisitStmt_(const IfThenElseNode *op) final {
170
171
172
173
    StmtVisitor::VisitStmt_(op);
    auto role = GetRole(op->then_case);
    if (op->else_case.defined()) {
      auto role_else = GetRole(op->else_case.value());
174
175
      if (role != role_else)
        role = Role::kBoth;
176
177
178
179
    }
    SetRole(op, role);
  }

180
  void VisitStmt_(const BlockRealizeNode *op) final {
181
182
183
184
    StmtVisitor::VisitStmt_(op);
    SetRole(op, GetRole(op->block));
  }

185
  template <class NodeType> void HandleBodyStmt(const NodeType *op) {
186
187
188
189
    StmtVisitor::VisitStmt_(op);
    SetRole(op, GetRole(op->body));
  }

190
191
192
193
194
  void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); }
195
196
197
198
199

  bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; }

  bool HasSimtCopy() { return has_simt_copy_; }

200
201
private:
  void SetRole(const StmtNode *stmt, Role role) { map_[stmt] = role; }
202
  Map<Var, Buffer> buffer_data_to_buffer_;
203
  std::unordered_map<const StmtNode *, Role> map_;
204
205
  bool has_simt_copy_ = false;
  bool has_bulk_copy_ = false;
206
  std::unordered_set<const BufferNode *> used_in_producer_cond_;
207
208
209
210
211
212
213
};

/*! \brief Build a handle-typed call to GetMBarrierOp with \p barrier_id as its
 *         only argument. */
static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
  Array<PrimExpr> args = {barrier_id};
  return Call(DataType::Handle(), GetMBarrierOp(), args);
}

/*! \brief Build an evaluated MBarrierExpectTX call on the barrier selected by
 *         \p barrier_id, passing \p bytes as the second argument. */
static Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) {
  Array<PrimExpr> args = {makeGetBarrier(barrier_id), bytes};
  return Evaluate(Call(DataType::Handle(), MBarrierExpectTX(), args));
}

/*! \brief Build an evaluated builtin::ptx_arrive_barrier call on the barrier
 *         selected by \p barrier_id. */
static Stmt makeArriveBarrier(PrimExpr barrier_id) {
  Array<PrimExpr> args = {makeGetBarrier(barrier_id)};
  return Evaluate(
      Call(DataType::Handle(), builtin::ptx_arrive_barrier(), args));
}

/*! \brief Build an evaluated builtin::ptx_cp_async_barrier call on the barrier
 *         selected by \p barrier_id. */
static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) {
  Array<PrimExpr> args = {makeGetBarrier(barrier_id)};
  return Evaluate(
      Call(DataType::Handle(), builtin::ptx_cp_async_barrier(), args));
}

/*! \brief Build an evaluated MBarrierWaitParity call on the barrier selected
 *         by \p barrier_id, passing \p parity as the second argument. */
static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
  Array<PrimExpr> args = {makeGetBarrier(barrier_id), parity};
  return Evaluate(Call(DataType::Handle(), MBarrierWaitParity(), args));
}

// static bool isGemm(Stmt stmt) {
//   bool is_gemm = false;
//   if (stmt.as<EvaluateNode>()) {
//     auto call = Downcast<Evaluate>(stmt)->value.as<CallNode>();
//     if (call && call->op.same_as(Op::Get("tir.call_extern"))) {
//       if (call->args[0].as<StringImmNode>()) {
//         std::string name = Downcast<StringImm>(call->args[0])->value;
//         if (name.find("gemm") != std::string::npos) {
//           is_gemm = true;
//         }
//       }
//     }
//   }
//   return is_gemm;
// }

253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
/*!
 * \brief Inserts an expect-tx statement in front of TMA loads.
 *
 * A TMA load evaluated outside of any loop is replaced by the sequence
 * {expect_tx, load}. A `for` loop in which a TMA load is found is replaced by
 * {expect_tx, original loop}, i.e. the expect-tx is hoisted in front of the
 * loop and the loop itself is left unchanged.
 */
class TMAExpectTxRewriter : public StmtExprMutator {
public:
  TMAExpectTxRewriter(Stmt expect_tx) : expect_tx_(expect_tx) {}

  /*!
   * \brief Rewrite \p stmt, inserting \p expect_tx before TMA loads.
   * \param stmt The statement to rewrite.
   * \param expect_tx The statement to insert.
   * \return The rewritten statement.
   */
  static Stmt Rewrite(Stmt stmt, Stmt expect_tx) {
    TMAExpectTxRewriter rewriter(expect_tx);
    return rewriter(stmt);
  }

private:
  Stmt VisitStmt_(const ForNode *op) final {
    // First pass: only detect TMA loads in the loop (the result is discarded),
    // with per-evaluate insertion disabled.
    insert_in_evaluate_ = false;
    StmtExprMutator::VisitStmt_(op);
    insert_in_evaluate_ = true;
    if (contain_tma_load_) {
      Array<Stmt> new_seq = {expect_tx_, GetRef<For>(op)};
      contain_tma_load_ = false;
      return SeqStmt(std::move(new_seq));
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const EvaluateNode *op) final {
    if (const CallNode *call = op->value.as<CallNode>()) {
      if (call->op.same_as(TMALoadOp()) ||
          call->op.same_as(TMALoadIm2ColOp())) {
        contain_tma_load_ = true;
        if (insert_in_evaluate_) {
          Array<Stmt> new_seq = {expect_tx_, GetRef<Evaluate>(op)};
          return SeqStmt(std::move(new_seq));
        }
      }
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  Stmt expect_tx_;
  // Must start out false: the ForNode visitor reads this flag even when no
  // TMA load has been visited yet.
  bool contain_tma_load_ = false;
  bool insert_in_evaluate_ = true;
};

293
class ProducerTraitsCollector : public StmtExprVisitor {
294
public:
295
296
297
298
299
300
301
302
303
304
305
306
307
308
  ProducerTraitsCollector() { Clear(); }

  void Clear() {
    bulk_copy_bytes = 0;
    loop_extents = 1;
    has_simt_copy = false;
  }

  void Collect(Stmt stmt) { VisitStmt(stmt); }

  bool HasSimtCopy() { return has_simt_copy; }

  PrimExpr BulkCopyBytes() { return bulk_copy_bytes; }

309
310
private:
  void VisitExpr_(const CallNode *call) final {
311
312
313
314
315
316
317
318
319
    if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) {
      Call access_ptr = Downcast<Call>(call->args[2]);
      ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
      int type_bytes = access_ptr->args[0]->dtype.bytes();
      bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes;
    }
    StmtExprVisitor::VisitExpr_(call);
  }

320
  void VisitStmt_(const ForNode *op) final {
321
322
323
324
325
326
    PrimExpr old_loop_evtents = loop_extents;
    loop_extents *= op->extent;
    StmtExprVisitor::VisitStmt_(op);
    loop_extents = old_loop_evtents;
  }

327
328
329
330
331
332
333
334
335
336
337
338
  void VisitStmt_(const IfThenElseNode *op) final {
    bool old_in_if_cond = in_if_cond_;
    in_if_cond_ = true;
    VisitExpr(op->condition);
    in_if_cond_ = old_in_if_cond;

    VisitStmt(op->then_case);
    if (op->else_case.defined()) {
      VisitStmt(op->else_case.value());
    }
  }

339
  void VisitExpr_(const BufferLoadNode *op) final {
340
341
342
    if (!in_if_cond_) {
      has_simt_copy = true;
    }
343
344
345
346
347
348
    StmtExprVisitor::VisitExpr_(op);
  }

  bool has_simt_copy;
  PrimExpr bulk_copy_bytes;
  PrimExpr loop_extents;
349
  bool in_if_cond_ = false;
350
351
352
353
};

// Rewrite the producer Stmt to use the correct barrier index
class MbarrierRewriter : public StmtExprMutator {
354
public:
355
356
357
358
359
360
  static Stmt Rewrite(Stmt stmt, PrimExpr barrier_id) {
    MbarrierRewriter rewriter;
    rewriter.producer_barrier_idx_ = barrier_id;
    return rewriter(stmt);
  }

361
362
private:
  PrimExpr VisitExpr_(const CallNode *op) final {
363
364
365
366
367
368
369
370
371
372
373
374
    auto call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
    if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) {
      Call access_ptr = Downcast<Call>(call->args[2]);
      ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
      call.CopyOnWrite()->args.Set(1, makeGetBarrier(producer_barrier_idx_));
    }
    return call;
  }
  PrimExpr producer_barrier_idx_;
};

class ThreadIdxRewriter : public StmtExprMutator {
375
public:
376
377
378
379
380
  static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced) {
    auto rewriter = ThreadIdxRewriter(thread_var, replaced);
    return rewriter(stmt);
  }

381
private:
382
383
384
  ThreadIdxRewriter(Var thread_var, PrimExpr replaced)
      : thread_var_(thread_var), replaced_(replaced) {}

385
  PrimExpr VisitExpr_(const VarNode *var) final {
386
387
388
389
390
391
392
393
394
395
396
    if (var == thread_var_.get()) {
      return replaced_;
    } else {
      return StmtExprMutator::VisitExpr_(var);
    }
  }

  Var thread_var_;
  PrimExpr replaced_;
};

397
398
399
400
401
402
/*!
 * \brief Wrap \p stmt in an unnamed Block that has no iter vars, regions or
 *        buffers and carries the given \p annotations.
 */
Block MakeGroupBlock(const Stmt &stmt,
                     const Map<String, ObjectRef> &annotations) {
  return Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
               /*name_hint=*/"",
               /*body*/ stmt,
               /*init=*/{}, /*alloc_buffers=*/{}, /*match_buffers=*/{},
               /*annotations=*/annotations);
}

/*!
 * \brief One pipeline op group: the statement indices it contains plus its
 *        pipeline order and stage.
 */
struct OpInfo {
  // Zero-initialized so a default-constructed OpInfo never holds
  // indeterminate values.
  int group_size = 0;
  int order = 0;
  int stage = 0;
  // Statement indices belonging to this group.
  std::vector<int> group;
};
struct PipelineInfo {
  std::vector<OpInfo> op_infos;

  PipelineInfo() = default;
414
415
  PipelineInfo(Array<Array<Integer>> group_info, Array<Integer> order_info,
               Array<Integer> stage_info) {
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
    int n = static_cast<int>(group_info.size());
    ICHECK(n == static_cast<int>(order_info.size()));
    ICHECK(n == static_cast<int>(stage_info.size()));
    // int cur_id = 0;
    for (int i = 0; i < n; i++) {
      OpInfo op_info;
      op_info.group_size = group_info[i].size();
      for (int j = 0; j < op_info.group_size; j++) {
        op_info.group.push_back(group_info[i][j].as<IntImmNode>()->value);
      }
      op_info.order = order_info[i].as<IntImmNode>()->value;
      op_info.stage = stage_info[i].as<IntImmNode>()->value;
      op_infos.push_back(op_info);
    }
  }

432
  PipelineInfo(const PipelineInfo &other) {
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
    for (auto op_info : other.op_infos) {
      op_infos.push_back(op_info);
    }
  }

  std::pair<int, int> FindStmt(int stmt_idx) {
    for (size_t i = 0; i < op_infos.size(); i++) {
      for (size_t j = 0; j < op_infos[i].group.size(); j++) {
        if (op_infos[i].group[j] == stmt_idx) {
          return std::make_pair(i, j);
        }
      }
    }
    return std::make_pair(-1, -1);
  }

  void UpdateOrder(int order) {
    for (int i = 0; i < static_cast<int>(op_infos.size()); i++) {
      if (op_infos[i].order >= order && op_infos[i].order > 0) {
        op_infos[i].order++;
      }
    }
  }

  int SplitOp(int stmt_idx) {
    auto pair = FindStmt(stmt_idx);
    int op_idx = pair.first;
    int inner_idx = pair.second;
    ICHECK(op_idx != -1);
    ICHECK(inner_idx != -1);
    OpInfo half0;
    OpInfo half1;
    // The order to do sync
    int sync_order = op_infos[op_idx].order + 1;
    UpdateOrder(sync_order);

    half0.group_size = inner_idx + 1;
    half0.order = op_infos[op_idx].order;
    half0.stage = op_infos[op_idx].stage;
    for (int i = 0; i <= inner_idx; i++) {
      half0.group.push_back(op_infos[op_idx].group[i]);
    }
    half1.group_size = op_infos[op_idx].group_size - inner_idx - 1;
    half1.order = op_infos[op_idx].order + 2;
    half1.stage = op_infos[op_idx].stage;
    for (int i = inner_idx + 1; i < op_infos[op_idx].group_size; i++) {
      half1.group.push_back(op_infos[op_idx].group[i]);
    }
    op_infos.erase(op_infos.begin() + op_idx);
    if (half0.group_size > 0) {
      op_infos.insert(op_infos.begin() + op_idx, half0);
    }
    if (half1.group_size > 0) {
      UpdateOrder(half1.order);
      op_infos.insert(op_infos.begin() + op_idx + 1, half1);
    }
    return sync_order;
  }

  void PrintPipelineInfo() {
    std::cout << "Print op_infos:" << std::endl;
    for (size_t i = 0; i < op_infos.size(); i++) {
495
496
      std::cout << i << " " << op_infos[i].group_size << " "
                << op_infos[i].order << " " << op_infos[i].stage << std::endl;
497
498
499
500
501
502
    }
    std::cout << "End of print" << std::endl;
  }
};

class GroupOpRewriter : public StmtExprMutator {
503
public:
504
505
  GroupOpRewriter(PipelineInfo pipeline_info) : pipeline_info_(pipeline_info) {}

506
507
private:
  Stmt VisitStmt_(const ForNode *op) final {
508
509
510
511
512
513
514
515
516
    Map<String, ObjectRef> annotations;
    annotations.Set(String("stmt_group"), Integer(1));
    auto original_node = (op->body).as<SeqStmtNode>();
    if (!original_node) {
      return GetRef<For>(op);
    }
    Array<Stmt> new_body;
    int cur_id = 0;
    for (int i = 0; i < static_cast<int>(pipeline_info_.op_infos.size()); i++) {
517
518
      if (pipeline_info_.op_infos[i].group_size == 0)
        continue;
519
      Array<Stmt> block_stmt;
520
521
      for (int j = 0;
           j < static_cast<int>(pipeline_info_.op_infos[i].group_size); j++) {
522
        // ICHECK(group_info_[i][j].as<IntImmNode>());
523
524
        // int index =
        // static_cast<int>(group_info_[i][j].as<IntImmNode>()->value);
525
526
527
528
529
530
        ICHECK(original_node->seq[cur_id].as<BlockNode>());
        auto block = original_node->seq[cur_id].as<BlockNode>();
        // TODO: handle nested seqstmt
        block_stmt.push_back(block->body);
        cur_id++;
      }
531
532
533
534
      new_body.push_back(MakeGroupBlock(block_stmt.size() == 1
                                            ? block_stmt[0]
                                            : SeqStmt(std::move(block_stmt)),
                                        annotations));
535
536
537
538
539
540
541
542
543
544
545
    }
    Array<Integer> order_anno;
    Array<Integer> stage_anno;
    for (auto op_info : pipeline_info_.op_infos) {
      order_anno.push_back(Integer(op_info.order));
      stage_anno.push_back(Integer(op_info.stage));
    }
    Map<String, ObjectRef> for_annotations = op->annotations;
    for_annotations.erase("tl_pipeline_group");
    for_annotations.Set("software_pipeline_order", order_anno);
    for_annotations.Set("software_pipeline_stage", stage_anno);
546
547
548
549
    For new_for =
        For(op->loop_var, op->min, op->extent, op->kind,
            new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body)),
            op->thread_binding, for_annotations);
550
551
552
553
554
555
    return new_for;
  }

  PipelineInfo pipeline_info_;
};
class WSCodeEmitter : public StmtMutator {
556
public:
557
  WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
558
                Map<Var, Buffer> buffer_data_to_buffer,
559
560
                const WarpSpecializedRoleMarker &marker,
                bool mbarrier_only = false)
561
      : is_emitting_producer_(is_emitting_producer),
562
        buffer_data_to_buffer_(buffer_data_to_buffer), marker_(marker),
563
        thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only) {}
564

565
566
private:
  /*!
   * \brief Keep, drop or recurse into \p op depending on its role.
   *
   * Statements marked kBoth (and, in mbarrier-only mode, every non-producer
   * statement) are recursed into. Otherwise the statement is kept unchanged
   * when its role matches the side being emitted and replaced by Evaluate(0)
   * when it does not.
   */
  template <typename NodeType> Stmt FilterByRole(const NodeType *op) {
    const Role role = marker_.GetRole(op);
    const bool recurse =
        role == Role::kBoth || (mbarrier_only_ && role != Role::kProducer);
    if (recurse) {
      return StmtMutator::VisitStmt_(op);
    }
    const bool is_producer_stmt = (role == Role::kProducer);
    if (is_producer_stmt != is_emitting_producer_) {
      return Evaluate(0);
    }
    return GetRef<Stmt>(op);
  }

  // TODO: only need to add block for ops in the loop
582
  Stmt VisitStmt_(const SeqStmtNode *op) final {
583

584
585
586
587
588
589
590
    bool has_producer = false;
    for (auto stmt : op->seq) {
      if (marker_.GetRole(stmt) == Role::kProducer) {
        has_producer = true;
        break;
      }
    }
591
592
593
594
    bool need_producer_sync =
        has_producer && marker_.GetRole(op) == Role::kBoth;
    if (!need_producer_sync)
      return FilterByRole(op);
595

596
597
    auto seq_transformed =
        op->seq.Map([&](Stmt stmt) { return VisitStmt(stmt); });
598
599

    auto map = ExtractSyncPattern(op->seq);
600
601
602
603
604
605
606
607
608
609
610
611
612
613
    /*
      std::cout << "Print ExtractSyncPattern" << std::endl;
      for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
        std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " "
        << map.release_after[i] << std::endl;
      }
      std::cout << "Print sync pattern" << std::endl;
      for (auto pattern : map.patterns) {
        std::cout << pattern.release_idx << " " << pattern.acquire_idx <<
        std::endl;
      }
      std::cout << "End of ExtractSyncPattern" << std::endl;
      pipeline_info_.PrintPipelineInfo();
    */
614
615
616
617
    Array<Stmt> new_body;
    Map<String, ObjectRef> annotations;
    annotations.Set(String("stmt_group"), Integer(1));

618
    if (is_emitting_producer_) { // producer case
619
620
621
      ProducerTraitsCollector collector;
      for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
        Array<Stmt> block_stmt = {};
622
623
624
625
626
627
628
629
630
631
632
        if (!mbarrier_only_) {
          if (marker_.GetRole(op->seq[i]) == Role::kConsumer)
            continue;
          if (marker_.GetRole(op->seq[i]) == Role::kBoth) {
            block_stmt.push_back(seq_transformed[i]);
            new_body.push_back(MakeGroupBlock(
                block_stmt.size() == 1 ? block_stmt[0]
                                       : SeqStmt(std::move(block_stmt)),
                annotations));
            continue;
          }
633
        }
634

635
        for (int pattern_idx : map.acquire[i]) {
636
          PrimExpr acquire_barrier_id =
637
638
              stage_ + num_barriers_ + num_stages_ * pattern_idx;
          PrimExpr parity = map.is_loop_dependency(pattern_idx)
639
640
                                ? bitwise_xor(parity_, 1)
                                : parity_;
641
642
          block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
        }
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
        ICHECK(map.release[i].size() > 0);
        for (size_t j = 0; j < map.release[i].size(); j++) {
          int pattern_idx = map.release[i][j];
          PrimExpr release_barrier_id =
              stage_ + num_barriers_ + num_stages_ * pattern_idx;
          auto stmt =
              MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id);
          collector.Collect(stmt);
          if (!is_zero(collector.BulkCopyBytes())) {
            auto expect_tx = IfThenElse(
                EQ(thread_var_, 0),
                makeExpectTX(release_barrier_id, collector.BulkCopyBytes()));
            block_stmt.push_back(TMAExpectTxRewriter::Rewrite(stmt, expect_tx));
          } else {
            block_stmt.push_back(stmt);
          }
          if (collector.HasSimtCopy() > 0) {
            block_stmt.push_back(makeCpAsyncBarrier(release_barrier_id));
661
          }
662
663
664
665
666
667
668
669
670
671
672
673
          if (map.release_after[i][j]) {
            block_stmt.push_back(makeArriveBarrier(release_barrier_id));
            for (int s = 0; s < num_stages_; s++) {
              released_barrier_.insert(s + num_barriers_ +
                                       num_stages_ * pattern_idx);
            }
          }
          collector.Clear();
          new_body.push_back(MakeGroupBlock(
              block_stmt.size() == 1 ? block_stmt[0]
                                     : SeqStmt(std::move(block_stmt)),
              annotations));
674
675
        }
      }
676
    } else { // consumer case
677
678
      for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
        Array<Stmt> block_stmt = {};
679
680
        if (marker_.GetRole(op->seq[i]) == Role::kProducer)
          continue;
681
        for (int pattern_idx : map.acquire[i]) {
682
          PrimExpr acquire_barrier_id =
683
684
              stage_ + num_barriers_ + num_stages_ * pattern_idx;
          PrimExpr parity = map.is_loop_dependency(pattern_idx)
685
686
                                ? bitwise_xor(parity_, 1)
                                : parity_;
687
688
689
          block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
        }
        block_stmt.push_back(seq_transformed[i]);
690
691
692
693
694
695
696
697
698
699
        for (size_t j = 0; j < map.release[i].size(); j++) {
          if (map.release_after[i][j]) {
            int pattern_idx = map.release[i][j];
            PrimExpr release_barrier_id =
                stage_ + num_barriers_ + num_stages_ * pattern_idx;
            block_stmt.push_back(makeArriveBarrier(release_barrier_id));
            for (int s = 0; s < num_stages_; s++) {
              released_barrier_.insert(s + num_barriers_ +
                                       num_stages_ * pattern_idx);
            }
700
701
          }
        }
702
703
704
705
        new_body.push_back(MakeGroupBlock(block_stmt.size() == 1
                                              ? block_stmt[0]
                                              : SeqStmt(std::move(block_stmt)),
                                          annotations));
706
707
708
709
      }
      // Filter out the producer stmts
      int cur_id = 0;
      PipelineInfo new_pipeline_info;
710
711
      for (int i = 0; i < static_cast<int>(pipeline_info_.op_infos.size());
           i++) {
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
        auto op_info = pipeline_info_.op_infos[i];
        bool is_producer = false;
        for (int j = 0; j < op_info.group_size; j++) {
          if (marker_.GetRole(op->seq[cur_id]) == Role::kProducer) {
            is_producer = true;
          }
          cur_id++;
        }
        if (is_producer) {
          ICHECK(op_info.group_size == 1);
        } else {
          new_pipeline_info.op_infos.push_back(op_info);
        }
      }
      pipeline_info_ = new_pipeline_info;
    }

    num_barriers_ += map.patterns.size() * num_stages_;

    ICHECK(new_body.size() > 0);
    return new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body));
  }

735
  /*!
   * \brief Emit one side of a `for` loop, tracking pipeline stage and parity.
   *
   * While the loop body is processed, `stage_` is (loop_var - min) mod
   * num_stages and `parity_` is (outer_parity * extent +
   * floordiv(loop_var - min, num_stages)) mod 2; both, together with
   * `num_stages_` and `pipeline_info_`, are restored afterwards. On the
   * consumer side a loop carrying a non-empty "tl_pipeline_group" annotation
   * is additionally regrouped by GroupOpRewriter.
   */
  Stmt VisitStmt_(const ForNode *op) final {
    int num_stages = 1;
    auto num_stages_anno = op->annotations.Get("num_stages");
    if (num_stages_anno.defined()) {
      ICHECK(num_stages_anno.as<IntImmNode>());
      num_stages = static_cast<int>(num_stages_anno.as<IntImmNode>()->value);
      ICHECK(num_stages_ == 1) << "Nested pipeline not supported.";
    }

    Array<Array<Integer>> group_info_array;
    Array<Integer> order_info_array;
    Array<Integer> stage_info_array;

    auto group_anno = op->annotations.Get("tl_pipeline_group");
    if (group_anno.defined()) {
      group_info_array = Downcast<Array<Array<Integer>>>(group_anno);
    }
    auto order_anno = op->annotations.Get("tl_pipeline_order");
    if (order_anno.defined()) {
      order_info_array = Downcast<Array<Integer>>(order_anno);
    }
    auto stage_anno = op->annotations.Get("tl_pipeline_stage");
    if (stage_anno.defined()) {
      stage_info_array = Downcast<Array<Integer>>(stage_anno);
    }

    PipelineInfo pipeline_info(group_info_array, order_info_array,
                               stage_info_array);
    if (pipeline_info.op_infos.size() > 0) {
      ICHECK(pipeline_info_.op_infos.size() == 0)
          << "Nested pipeline not supported.";
    }

    // Save the enclosing loop state; it is restored once the body is done.
    PrimExpr parity_before = std::move(parity_);
    PrimExpr stage_before = std::move(stage_);
    int num_stages_before = num_stages_;
    PipelineInfo pipeline_info_before = pipeline_info_;

    num_stages_ = num_stages;
    pipeline_info_ = pipeline_info;
    stage_ = FloorMod(op->loop_var - op->min, num_stages);
    parity_ = FloorMod(parity_before * op->extent +
                           FloorDiv(op->loop_var - op->min, num_stages),
                       2);

    auto result = FilterByRole(op);

    Stmt grouped_for_node;
    if (result.as<ForNode>() && group_anno.defined() &&
        group_info_array.size() > 0 && !is_emitting_producer_) {
      // pipeline_info_ may have been filtered while visiting the loop body.
      GroupOpRewriter group_op_rewriter(pipeline_info_);
      auto for_node = Downcast<For>(result);
      grouped_for_node = group_op_rewriter(for_node);
    }

    parity_ = std::move(parity_before);
    stage_ = std::move(stage_before);
    num_stages_ = num_stages_before;
    pipeline_info_ = pipeline_info_before;

    // remove pipeline annotation
    if (result.as<ForNode>()) {
      auto for_node = Downcast<For>(result);
      for_node.CopyOnWrite()->annotations.erase("num_stages");
      if (is_emitting_producer_ || group_info_array.size() == 0) {
        for_node.CopyOnWrite()->annotations.erase("tl_pipeline_order");
        for_node.CopyOnWrite()->annotations.erase("tl_pipeline_stage");
      }
      if (is_emitting_producer_ || !group_anno.defined() ||
          group_info_array.size() == 0) {
        return for_node;
      }
      return grouped_for_node;
    }
    return result;
  }

813
814
815
816
817
818
819
  // These statement kinds need no special handling: each is kept, dropped or
  // recursed into purely according to its role (see FilterByRole).
  Stmt VisitStmt_(const IfThenElseNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const EvaluateNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const AttrStmtNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const BufferStoreNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const LetStmtNode *op) final { return FilterByRole(op); }
  Stmt VisitStmt_(const AssertStmtNode *op) final { return FilterByRole(op); }
  /*! \brief Visiting a BlockNode is always a fatal error in this emitter. */
  Stmt VisitStmt_(const BlockNode *op) final {
    ICHECK(0) << "WSCodeEmitter does not expect to visit a BlockNode.";
    return Stmt();
  }
823
  /*! \brief Visiting a BlockRealizeNode is always a fatal error in this
   *         emitter. */
  Stmt VisitStmt_(const BlockRealizeNode *op) final {
    ICHECK(0) << "WSCodeEmitter does not expect to visit a BlockRealizeNode.";
    return Stmt();
  }

  /*! \brief A synchronization pair: the index of the statement that releases
   *         and the index of the statement that acquires. */
  struct SyncPattern {
    int release_idx;
    int acquire_idx;
  };

  struct SyncPatternMap {
833
834
835
    std::vector<std::vector<int>> acquire;
    std::vector<std::vector<int>> release;
    std::vector<std::vector<bool>> release_after;
836
    std::vector<SyncPattern> patterns;
837
838
839
840
841
842
843
844
845
846

    void resize(size_t n) {
      acquire.resize(n);
      release.resize(n);
      release_after.resize(n);
    }

    bool is_loop_dependency(int pattern_idx) {
      return patterns[pattern_idx].release_idx >
             patterns[pattern_idx].acquire_idx;
847
848
849
    }
  };

850
851
852
  /*!
   * \brief Find the basic release/acquire pairs between producer and consumer
   *        statements of a sequence.
   *
   * Two statements are paired when exactly one of them is a producer and one
   * of them writes a buffer that the other reads. For every statement i the
   * first such later statement j yields the pair {i, j}. When `parity_` is
   * not zero, the first such earlier statement j also yields a pair {i, j}.
   *
   * \param seq_stmt The statements of the sequence.
   * \param is_producer Whether each statement is a producer.
   * \return The sync patterns, forward pairs first.
   */
  std::vector<SyncPattern>
  CreateBaseSyncPairs(Array<Stmt> seq_stmt,
                      const std::vector<bool> &is_producer) {
    using BufferSet = std::set<const BufferNode *>;

    const int n = seq_stmt.size();
    std::vector<BufferSet> reads, writes;
    reads.reserve(n);
    writes.reserve(n);
    for (int idx = 0; idx < n; idx++) {
      // Wrap the statement in a block to query the buffers it reads/writes.
      Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
                  /*name_hint=*/"",
                  /*body*/ seq_stmt[idx]);
      auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
      BufferSet read_set, write_set;
      for (auto region : access[0])
        read_set.insert(region->buffer.get());
      for (auto region : access[1])
        write_set.insert(region->buffer.get());
      reads.push_back(std::move(read_set));
      writes.push_back(std::move(write_set));
    }

    auto overlaps = [](const BufferSet &lhs, const BufferSet &rhs) {
      for (auto ptr : lhs) {
        if (rhs.count(ptr))
          return true;
      }
      return false;
    };
    // True when statements a and b are on different sides and share a buffer
    // that one writes and the other reads.
    auto needs_sync = [&](int a, int b) {
      return is_producer[a] != is_producer[b] &&
             (overlaps(writes[a], reads[b]) || overlaps(reads[a], writes[b]));
    };

    std::vector<SyncPattern> sync_patterns;
    // Forward pairs: each statement with the first later statement it must
    // synchronize with.
    for (int i = 0; i < n; i++) {
      for (int j = i + 1; j < n; j++) {
        if (needs_sync(i, j)) {
          sync_patterns.push_back({i, j});
          break;
        }
      }
    }

    // Backward pairs are only generated when parity_ is not zero.
    if (is_zero(parity_)) {
      return sync_patterns;
    }
    for (int i = 0; i < n; i++) {
      for (int j = 0; j < i; j++) {
        if (needs_sync(i, j)) {
          sync_patterns.push_back({i, j});
          break;
        }
      }
    }
    return sync_patterns;
  }

913
914
915
  static std::vector<SyncPattern>
  RemoveUnusedSyncPatterns(const std::vector<SyncPattern> &sync_patterns,
                           const std::vector<bool> &is_producer) {
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
    /*
      Simplify multiple release-acquire pairs into one
      ------------------
        Produce(A)
        Produce(B)
        Consume(A, B)
      ------------------
      [(0, 2), (1, 2), (2, 0)] -> [(1, 2), (2, 0)]

      Or
      ------------------
        Produce(A, B)
        Consume(A)
        Consume(B)
      ------------------
      [(0, 1), (1, 0), (2, 0)] -> [(0, 1), (2, 0)]
    */
    int M = sync_patterns.size();
    std::vector<bool> removed(M, false);
    for (int i = 0; i < M; i++) {
      for (int j = 0; j < M; j++) {
        if (is_producer[sync_patterns[i].acquire_idx] ==
                is_producer[sync_patterns[j].acquire_idx] &&
            sync_patterns[i].acquire_idx >= sync_patterns[j].acquire_idx &&
            sync_patterns[i].release_idx < sync_patterns[j].release_idx)
          removed[i] = true;
      }
    }

    std::vector<SyncPattern> sync_pattern_cleaned;
    sync_pattern_cleaned.reserve(M);
    for (int i = 0; i < M; i++)
948
949
      if (!removed[i])
        sync_pattern_cleaned.push_back(sync_patterns[i]);
950
951
952
953
954
955
956
957
958
959
960
961
962

    return sync_pattern_cleaned;
  }

  // Computes the synchronization map for a statement sequence.
  //
  // Steps:
  //   1. classify each statement as producer / non-producer through marker_,
  //   2. build the candidate pairs and prune the redundant ones,
  //   3. attach every surviving pattern to its acquire and release statement
  //      (release entries created here are flagged release_after = true),
  //   4. walk the statements backwards: a statement without any release of
  //      its own receives the patterns released by later statements of the
  //      same role, flagged release_after = false.
  //
  // \param seq_stmt statements of the sequence, in program order
  // \return the populated SyncPatternMap
  SyncPatternMap ExtractSyncPattern(Array<Stmt> seq_stmt) {
    const size_t num_stmts = seq_stmt.size();

    std::vector<bool> is_producer;
    is_producer.reserve(num_stmts);
    for (auto stmt : seq_stmt) {
      const bool producer_role = marker_.GetRole(stmt) == Role::kProducer;
      is_producer.push_back(producer_role);
    }

    auto base_pairs = CreateBaseSyncPairs(seq_stmt, is_producer);
    auto sync_patterns = RemoveUnusedSyncPatterns(base_pairs, is_producer);

    SyncPatternMap map;
    map.resize(num_stmts);
    map.patterns = sync_patterns;

    for (size_t pattern_id = 0; pattern_id < sync_patterns.size();
         pattern_id++) {
      const SyncPattern &pattern = sync_patterns[pattern_id];
      map.acquire[pattern.acquire_idx].push_back(pattern_id);
      map.release[pattern.release_idx].push_back(pattern_id);
      map.release_after[pattern.release_idx].push_back(true);
    }

    // Patterns released so far (scanning from the end), tracked per role.
    std::vector<int> pending_consumer, pending_producer;
    for (int i = num_stmts - 1; i >= 0; i--) {
      std::vector<int> &pending =
          is_producer[i] ? pending_producer : pending_consumer;
      if (map.release[i].empty()) {
        // No release of its own: inherit what later same-role statements
        // release.
        for (auto pattern_idx : pending) {
          map.release[i].push_back(pattern_idx);
          map.release_after[i].push_back(false);
        }
      } else {
        // Remember this statement's releases for earlier same-role statements.
        for (auto pattern_idx : map.release[i]) {
          pending.push_back(pattern_idx);
        }
      }
    }
    return map;
  }

  const bool is_emitting_producer_;
  Map<Var, Buffer> buffer_data_to_buffer_;
  std::unordered_set<int> released_barrier_;
1016
  const WarpSpecializedRoleMarker &marker_;
1017
1018
1019
1020
1021
1022

  int num_barriers_ = 0;
  PrimExpr parity_ = 0;
  PrimExpr stage_ = 0;
  int num_stages_ = 1;
  Var thread_var_;
1023
  bool mbarrier_only_ = false;
1024
1025
1026
1027
  PipelineInfo pipeline_info_;
  friend class WarpSpecializedRewriter;
};

1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
class ThreadTagChecker : public StmtExprVisitor {
public:
  static bool HasOnlyThreadIdxX(const PrimFunc &f) {
    ThreadTagChecker checker;
    checker(f->body);
    return checker.is_valid_;
  }

private:
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
1039
1040
1041
1042
1043
1044
      IterVar iter_var = Downcast<IterVar>(op->node);
      String thread_tag = iter_var->thread_tag;
      bool is_y_or_z =
          thread_tag == "threadIdx.y" || thread_tag == "threadIdx.z";

      if (!thread_tag.empty() && is_y_or_z && !is_one(iter_var->dom->extent)) {
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
        is_valid_ = false;
      }
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const ForNode *op) final {
    if (op->kind == ForKind::kThreadBinding) {
      ICHECK(op->thread_binding.defined());
      String thread_tag = op->thread_binding.value()->thread_tag;
1055
1056
1057
1058
1059
1060
1061
1062
      bool is_y_or_z =
          thread_tag == "threadIdx.y" || thread_tag == "threadIdx.z";
      if (!thread_tag.empty() && is_y_or_z) {
        auto iter_var = Downcast<IterVar>(op->thread_binding);
        if (iter_var.defined() && iter_var->dom.defined() &&
            !is_one(iter_var->dom->extent)) {
          is_valid_ = false;
        }
1063
1064
1065
1066
1067
1068
1069
1070
      }
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  bool is_valid_ = true;
};

1071
1072
1073
1074
1075
class SetMaxNRegCollector : public StmtExprVisitor {
public:
  static Array<IntImm> Collect(const PrimFunc &f) {
    SetMaxNRegCollector collector;
    collector(f->body);
1076
1077
1078
1079
    return collector.has_no_set_max_nreg_
               ? Array<IntImm>({IntImm(DataType::Int(32), -1),
                                IntImm(DataType::Int(32), -1)})
               : collector.nreg_;
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
  }

private:
  void VisitStmt_(const EvaluateNode *op) final {
    if (const CallNode *call = op->value.as<CallNode>()) {
      if (call->op.same_as(SetMaxNReg())) {
        int reg_hint = call->args[0].as<IntImmNode>()->value;
        int is_inc = call->args[1].as<IntImmNode>()->value;
        ICHECK(reg_hint <= 240 && reg_hint >= 24)
            << "Invalid reg hint: " << reg_hint;
        ICHECK(is_inc == 0 || is_inc == 1) << "Invalid is_inc: " << is_inc;

        // producer should decrease register hint while consumer should increase
        // register hint
        nreg_.Set(is_inc, IntImm(DataType::Int(32), reg_hint));
1095
1096
      } else if (call->op.same_as(NoSetMaxNReg())) {
        has_no_set_max_nreg_ = true;
1097
1098
1099
1100
1101
1102
1103
      }
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  Array<IntImm> nreg_{IntImm(DataType::Int(32), 0),
                      IntImm(DataType::Int(32), 0)};
1104
  bool has_no_set_max_nreg_ = false;
1105
1106
};

1107
class WarpSpecializedRewriter : public StmtExprMutator {
1108
public:
1109
1110
1111
  WarpSpecializedRewriter(bool disable_warp_specialized)
      : disable_warp_specialized_(disable_warp_specialized) {}
  static PrimFunc Substitute(PrimFunc f, bool disable_warp_specialized) {
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
    // Check if function only uses threadIdx.x before proceeding
    if (!ThreadTagChecker::HasOnlyThreadIdxX(f)) {
      LOG(WARNING) << "WarpSpecialize will be disabled because the program "
                      "uses thread tags other than threadIdx.x\n"
                   << "If you want to use warp specialization, please refactor "
                      "your program to use threadIdx.x only";
      // Return original function unchanged if other thread tags are found
      return f;
    }

1122
    auto T = WarpSpecializedRewriter(disable_warp_specialized);
1123
    T.nreg_ = SetMaxNRegCollector::Collect(f);
1124
    T.buffer_lca_ = DetectBufferAccessLCA(f);
1125
1126
    for (auto [buffer, _] : T.buffer_lca_)
      T.buffer_data_to_buffer_.Set(buffer->data, buffer);
1127
1128
1129
1130
    f.CopyOnWrite()->body = T(f->body);
    return f;
  }

1131
1132
private:
  Stmt VisitStmt_(const AttrStmtNode *op) final {
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
    if (op->attr_key == tir::attr::thread_extent &&
        Downcast<IterVar>(op->node)->thread_tag == "threadIdx.x") {
      thread_iv_ = Downcast<IterVar>(op->node);
      need_update_thread_extent_ = false;
      AttrStmt attr_stmt = Downcast<AttrStmt>(StmtExprMutator::VisitStmt_(op));
      if (need_update_thread_extent_) {
        thread_iv_.CopyOnWrite()->dom = {0, updated_thread_extent_.value()};
        attr_stmt.CopyOnWrite()->node = thread_iv_;
        attr_stmt.CopyOnWrite()->value = updated_thread_extent_.value();
      }
      thread_iv_ = {};
      return attr_stmt;
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

1150
1151
  Stmt VisitStmt_(const EvaluateNode *op) final {
    if (const CallNode *call = op->value.as<CallNode>()) {
1152
      if (call->op.same_as(SetMaxNReg()) || call->op.same_as(NoSetMaxNReg())) {
1153
1154
1155
1156
1157
1158
        return Evaluate(0);
      }
    }
    return StmtExprMutator::VisitStmt_(op);
  }

1159
1160
1161
1162
  // If users define a thread binding, we will replace the thread binding with
  // threadIdx.x We require the thread binding is threadIdx.x, and the extent is
  // the same as the thread extent
  Stmt VisitStmt_(const ForNode *op) final {
1163
1164
1165
1166
1167
1168
1169
    ICHECK(thread_iv_.defined());
    For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
    if (for_node->kind == ForKind::kThreadBinding) {
      ICHECK(for_node->thread_binding.defined());
      String thread_tag = for_node->thread_binding.value()->thread_tag;
      ICHECK(thread_tag == "threadIdx.x") << "Only support threadIdx.x";
      Var thread_iv = Downcast<Var>(for_node->loop_var);
1170
1171
      Stmt new_body =
          ThreadIdxRewriter::Rewrite(for_node->body, thread_iv, thread_iv_);
1172
1173
1174
1175
1176
      return new_body;
    }
    return for_node;
  }

1177
1178
1179
  Stmt VisitStmt_(const BlockRealizeNode *op) final {
    BlockRealize block_realize =
        Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
1180
1181
1182
1183
1184
1185
    if (!thread_iv_.defined()) {
      return block_realize;
    }

    Block block = block_realize->block;
    WarpSpecializedRoleMarker marker(buffer_data_to_buffer_);
1186
    marker.Prepare(block);
1187
1188
1189
1190
1191
1192
    marker(block);
    if (!marker.HasProducer()) {
      // Cannot detect any producer here, directly return.
      return block_realize;
    }

1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
    if (disable_warp_specialized_) {
      WSCodeEmitter mbarrier_emitter(true, thread_iv_, buffer_data_to_buffer_,
                                     marker, true);
      auto code = mbarrier_emitter(block->body);
      int num_barriers = mbarrier_emitter.num_barriers_;
      Array<PrimExpr> barrier_num_threads;
      barrier_num_threads.reserve(num_barriers);
      PrimExpr arrive_thread_count = thread_iv_->dom->extent;
      for (int i = 0; i < num_barriers; i++) {
        barrier_num_threads.push_back(arrive_thread_count);
      }
      Stmt init_barrier = Evaluate(Call(
          DataType::Handle(), CreateListofMBarrierOp(), barrier_num_threads));
      block.CopyOnWrite()->body = SeqStmt({init_barrier, code});
      block_realize.CopyOnWrite()->block = block;
      return block_realize;
    }
1210
1211
1212
1213
1214
1215
1216
    WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker);
    WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker);
    Stmt producer_code = producer(block->body);
    Stmt consumer_code = consumer(block->body);
    PrimExpr consumer_thread_extent = thread_iv_->dom->extent;
    PrimExpr producer_thread_extent = thread_iv_->dom->extent;
    // Need one warp-group for bulk-copy only case
1217
1218
    if (!marker.HasSimtCopy())
      producer_thread_extent = 128;
1219
1220

    // TODO: estimate the correct reg usage.
1221
1222
1223
    int dec_reg = nreg_[0].as<IntImmNode>()->value;
    int inc_reg = nreg_[1].as<IntImmNode>()->value;

1224
1225
1226
1227
1228
1229
1230
1231
    auto inc_reg_stmt = Evaluate(0);
    auto dec_reg_stmt = Evaluate(0);
    if (dec_reg >= 0 && inc_reg >= 0) {
      inc_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(),
                                   {inc_reg == 0 ? 240 : inc_reg, 1}));
      dec_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(),
                                   {dec_reg == 0 ? 24 : dec_reg, 0}));
    }
1232
1233
1234
1235

    producer_code = SeqStmt({dec_reg_stmt, producer_code});
    consumer_code = SeqStmt({inc_reg_stmt, consumer_code});

1236
1237
1238
    producer_code =
        ThreadIdxRewriter::Rewrite(producer_code, thread_iv_->var,
                                   thread_iv_->var - consumer_thread_extent);
1239
1240
1241
1242
1243
1244
1245
1246
1247
    updated_thread_extent_ = consumer_thread_extent + producer_thread_extent;
    need_update_thread_extent_ = true;

    ICHECK(producer.num_barriers_ == consumer.num_barriers_)
        << producer.num_barriers_ << " " << consumer.num_barriers_;
    int num_barriers = consumer.num_barriers_;
    Array<PrimExpr> barrier_num_threads;
    barrier_num_threads.reserve(num_barriers);
    for (int i = 0; i < num_barriers; i++) {
1248
1249
1250
      PrimExpr arrive_thread_count = producer.released_barrier_.count(i)
                                         ? producer_thread_extent
                                         : consumer_thread_extent;
1251
1252
1253
      barrier_num_threads.push_back(arrive_thread_count);
    }

1254
1255
1256
1257
    Stmt init_barrier = Evaluate(Call(
        DataType::Handle(), CreateListofMBarrierOp(), barrier_num_threads));
    Stmt body = IfThenElse(GE(thread_iv_->var, consumer_thread_extent),
                           producer_code, consumer_code);
1258
    // Add an attr here to handle the partial thread count in ThreadSync pass.
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
    Array<IntImm> ws_partition = {Downcast<IntImm>(producer_thread_extent),
                                  Downcast<IntImm>(consumer_thread_extent)};
    body = AttrStmt(ws_partition, "kWarpSpecializationScope", 0, body);

    block.CopyOnWrite()->body = SeqStmt({init_barrier, body});
    block_realize.CopyOnWrite()->block = block;
    return block_realize;
  }

  WarpSpecializedRewriter() = default;

  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Optional<Stmt>> buffer_lca_;
  Map<Buffer, Buffer> buffer_remap_;
  IterVar thread_iv_;
  Optional<PrimExpr> updated_thread_extent_;
  bool need_update_thread_extent_ = false;
1276
  bool disable_warp_specialized_ = false;
1277
  Array<IntImm> nreg_;
1278
1279
1280
1281
1282
1283
};

using namespace tir::transform;

// Creates the "tl.WarpSpecialized" PrimFunc pass (opt level 0, no required
// passes). For each function the pass reads the kDisableWarpSpecialized
// option from the pass context (false when unset) and forwards it, together
// with the function, to WarpSpecializedRewriter::Substitute.
tvm::transform::Pass WarpSpecialized() {
  return CreatePrimFuncPass(
      [=](PrimFunc f, IRModule m, PassContext ctx) {
        const auto disable_opt =
            ctx->GetConfig<Bool>(kDisableWarpSpecialized, Bool(false));
        const bool disable_warp_specialized = disable_opt.value();
        return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized);
      },
      0, "tl.WarpSpecialized", {});
}

1291
1292
// Registers the WarpSpecialized pass factory in TVM's global function
// registry under the name "tl.transform.WarpSpecialized".
TVM_REGISTER_GLOBAL("tl.transform.WarpSpecialized")
    .set_body_typed(WarpSpecialized);
1293

1294
1295
} // namespace tl
} // namespace tvm