/*!
 * \file annotate_warp_group_reg_alloc.cc
 * \brief Annotate warp group reg alloc for warp specialization
 */

#include "warp_specialized_rewriter.h"
#include <unordered_set>
#include <vector>

namespace tvm {
namespace tl {

using namespace tir;

class SetMaxNRegCollector : public StmtExprVisitor {
public:
  static Array<IntImm> Collect(const PrimFunc &f) {
    SetMaxNRegCollector collector;
    collector(f->body);
20
21
22
    if (collector.warp_specialized_) {
      return Array<IntImm>({});
    }
23
24
25
26
27
28
29
30
31
32
    return collector.has_no_set_max_nreg_
               ? Array<IntImm>({IntImm(DataType::Int(32), -1),
                                IntImm(DataType::Int(32), -1)})
               : collector.nreg_;
  }

private:
  void VisitStmt_(const EvaluateNode *op) final {
    if (const CallNode *call = op->value.as<CallNode>()) {
      if (call->op.same_as(set_max_nreg())) {
33
34
        auto reg_hint = call->args[0].as<IntImmNode>()->value;
        auto is_inc = call->args[1].as<IntImmNode>()->value;
35
36
37
38
39
40
41
42
43
44
45
46
47
48
        ICHECK(reg_hint <= 240 && reg_hint >= 24)
            << "Invalid reg hint: " << reg_hint;
        ICHECK(is_inc == 0 || is_inc == 1) << "Invalid is_inc: " << is_inc;

        // producer should decrease register hint while consumer should increase
        // register hint
        nreg_.Set(is_inc, IntImm(DataType::Int(32), reg_hint));
      } else if (call->op.same_as(no_set_max_nreg())) {
        has_no_set_max_nreg_ = true;
      }
    }
    StmtExprVisitor::VisitStmt_(op);
  }

49
50
51
52
53
54
55
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == attr::kCustomWarpSpecialization) {
      warp_specialized_ = true;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

56
57
58
  Array<IntImm> nreg_{IntImm(DataType::Int(32), 0),
                      IntImm(DataType::Int(32), 0)};
  bool has_no_set_max_nreg_ = false;
59
  bool warp_specialized_ = false;
60
61
62
63
64
65
66
};

/*!
 * \brief Mutator that strips the original set_max_nreg / no_set_max_nreg
 *        calls and prepends a register-setting statement to each branch of
 *        the IfThenElse found under a kWarpSpecializationScope attribute.
 */
class SetMaxNRegInjector : public StmtExprMutator {
public:
  /*!
   * \brief Rewrite the body of \p f using the hints collected from it.
   * \param f The function to rewrite.
   * \return \p f unchanged when SetMaxNRegCollector::Collect returns an empty
   *         array; otherwise \p f with its body replaced by the mutated body.
   */
  static PrimFunc Inject(PrimFunc f) {
    auto T = SetMaxNRegInjector();
    T.nreg_ = SetMaxNRegCollector::Collect(f);
    if (T.nreg_.empty()) {
      return f;
    }
    f.CopyOnWrite()->body = T(f->body);
    return f;
  }

private:
  Stmt VisitStmt_(const EvaluateNode *op) final {
    if (const CallNode *call = op->value.as<CallNode>()) {
      if (call->op.same_as(set_max_nreg()) ||
          call->op.same_as(no_set_max_nreg())) {
        // Remove the original set_max_nreg calls as they will be re-inserted
        // at appropriate locations
        return Evaluate(0);
      }
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent &&
        Downcast<IterVar>(op->node)->thread_tag == "threadIdx.x") {
      thread_iv_ = Downcast<IterVar>(op->node);
      need_update_thread_extent_ = false;
      AttrStmt attr_stmt = Downcast<AttrStmt>(StmtExprMutator::VisitStmt_(op));
      // NOTE(review): nothing in this class sets need_update_thread_extent_ to
      // true or assigns updated_thread_extent_, so this branch is currently
      // never taken; kept as-is.
      if (need_update_thread_extent_) {
        thread_iv_.CopyOnWrite()->dom = {0, updated_thread_extent_.value()};
        attr_stmt.CopyOnWrite()->node = thread_iv_;
        attr_stmt.CopyOnWrite()->value = updated_thread_extent_.value();
      }
      thread_iv_ = {};
      return attr_stmt;
    } else if (op->attr_key == attr::kWarpSpecializationScope) {
      // Use a checked cast that yields nullptr on mismatch, so a body that is
      // not an IfThenElse falls back to the default traversal instead of
      // failing inside Downcast.
      const IfThenElseNode *if_then_else = op->body.as<IfThenElseNode>();
      if (if_then_else == nullptr) {
        return StmtExprMutator::VisitStmt_(op);
      }
      // NOTE(review): the branch bodies are reused without being visited, so
      // set_max_nreg calls nested inside them are not stripped — confirm this
      // is intended.
      Stmt producer_body = if_then_else->then_case;
      Optional<Stmt> consumer_body = if_then_else->else_case;
      ICHECK(consumer_body.defined()) << "Consumer body is undefined";

      // nreg_[0] is the hint from is_inc == 0 calls, nreg_[1] from is_inc == 1.
      auto dec_reg = nreg_[0]->value;
      auto inc_reg = nreg_[1]->value;

      // Default to no-op statements; they stay in place when hints are
      // negative (i.e. a no_set_max_nreg call was collected).
      auto inc_reg_stmt = Evaluate(0);
      auto dec_reg_stmt = Evaluate(0);

      // Only inject if we have valid register hints and no SIMT copy
      // For now, we assume no SIMT copy detection is available here
      // TODO: Add SIMT copy detection if needed
      bool has_simt_copy = false; // Placeholder

      if (dec_reg >= 0 && inc_reg >= 0 && !has_simt_copy) {
        // A hint of 0 means "not specified": fall back to 240 for the
        // increasing side and 24 for the decreasing side.
        auto inc_reg_num =
            IntImm(DataType::Int(32), inc_reg == 0 ? 240 : inc_reg);
        auto dec_reg_num =
            IntImm(DataType::Int(32), dec_reg == 0 ? 24 : dec_reg);
        inc_reg_stmt = Evaluate(
            Call(DataType::Handle(), set_max_nreg(), {inc_reg_num, 1}));
        dec_reg_stmt = Evaluate(
            Call(DataType::Handle(), set_max_nreg(), {dec_reg_num, 0}));
      }

      // Inject register setting statements
      Array<Stmt> producer_stmts;
      producer_stmts.push_back(dec_reg_stmt);
      producer_stmts.push_back(producer_body);
      auto new_producer_body = SeqStmt(producer_stmts);

      Array<Stmt> consumer_stmts;
      consumer_stmts.push_back(inc_reg_stmt);
      consumer_stmts.push_back(consumer_body.value());
      auto new_consumer_body = SeqStmt(consumer_stmts);

      auto new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body,
                                    new_consumer_body);
      auto new_attr = AttrStmt(op->node, op->attr_key, op->value, new_if_stmt);

      return new_attr;
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

  // Hints returned by SetMaxNRegCollector::Collect for the current function.
  Array<IntImm> nreg_;
  // The threadIdx.x IterVar while its thread_extent AttrStmt is being visited.
  IterVar thread_iv_;
  Optional<PrimExpr> updated_thread_extent_;
  bool need_update_thread_extent_ = false;
};

using namespace tir::transform;

/*!
 * \brief Build the PrimFunc-level pass that runs SetMaxNRegInjector::Inject
 *        on each function.
 * \return A pass named "tl.AnnotateWarpGroupRegAlloc" created at opt level 0.
 */
tvm::transform::Pass AnnotateWarpGroupRegAlloc() {
  // The module and pass context are part of the callback signature but are
  // not consulted by the rewrite.
  auto inject_reg_alloc = [](PrimFunc func, const IRModule &mod,
                             const PassContext &pass_ctx) -> PrimFunc {
    PrimFunc rewritten = SetMaxNRegInjector::Inject(std::move(func));
    return rewritten;
  };
  return CreatePrimFuncPass(inject_reg_alloc, 0,
                            "tl.AnnotateWarpGroupRegAlloc", {});
}

// Registers AnnotateWarpGroupRegAlloc as a global function under the name
// "tl.transform.AnnotateWarpGroupRegAlloc" via the reflection GlobalDef API.
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.AnnotateWarpGroupRegAlloc",
                        AnnotateWarpGroupRegAlloc);
});

} // namespace tl
} // namespace tvm