"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "c42aded1beb2d3ad8dfa6d7355184a1bf3cf3316"
annotate_warp_group_reg_alloc.cc 6.08 KB
Newer Older
1
2
3
4
5
/*!
 * \file annotate_warp_group_reg_alloc.cc
 * \brief Annotate warp group reg alloc for warp specialization
 */

6
#include "warp_specialized_rewriter.h"
7
8
9
10
11
12
13
14
15
16
17
18
19
#include <unordered_set>
#include <vector>

namespace tvm {
namespace tl {

using namespace tir;

class SetMaxNRegCollector : public StmtExprVisitor {
public:
  static Array<IntImm> Collect(const PrimFunc &f) {
    SetMaxNRegCollector collector;
    collector(f->body);
20
21
22
    if (collector.warp_specialized_) {
      return Array<IntImm>({});
    }
23
24
25
26
27
28
29
30
31
32
    return collector.has_no_set_max_nreg_
               ? Array<IntImm>({IntImm(DataType::Int(32), -1),
                                IntImm(DataType::Int(32), -1)})
               : collector.nreg_;
  }

private:
  void VisitStmt_(const EvaluateNode *op) final {
    if (const CallNode *call = op->value.as<CallNode>()) {
      if (call->op.same_as(set_max_nreg())) {
33
34
        auto reg_hint = call->args[0].as<IntImmNode>()->value;
        auto is_inc = call->args[1].as<IntImmNode>()->value;
35
36
37
38
39
40
41
42
43
44
45
46
47
48
        ICHECK(reg_hint <= 240 && reg_hint >= 24)
            << "Invalid reg hint: " << reg_hint;
        ICHECK(is_inc == 0 || is_inc == 1) << "Invalid is_inc: " << is_inc;

        // producer should decrease register hint while consumer should increase
        // register hint
        nreg_.Set(is_inc, IntImm(DataType::Int(32), reg_hint));
      } else if (call->op.same_as(no_set_max_nreg())) {
        has_no_set_max_nreg_ = true;
      }
    }
    StmtExprVisitor::VisitStmt_(op);
  }

49
50
51
52
53
54
55
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == attr::kCustomWarpSpecialization) {
      warp_specialized_ = true;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

56
57
58
  Array<IntImm> nreg_{IntImm(DataType::Int(32), 0),
                      IntImm(DataType::Int(32), 0)};
  bool has_no_set_max_nreg_ = false;
59
  bool warp_specialized_ = false;
60
61
};

62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
/*!
 * \brief Reports whether a statement contains a BufferStore whose destination
 *        buffer has a storage scope string other than "global".
 */
class SimtCopyDetector : public StmtExprVisitor {
public:
  /*!
   * \brief Scan a statement for stores to non-"global" buffers.
   * \param stmt The statement to scan.
   * \return true if at least one such store was found.
   */
  static bool Detect(const Stmt &stmt) {
    SimtCopyDetector scanner;
    scanner.VisitStmt(stmt);
    return scanner.has_simt_copy_;
  }

private:
  void VisitStmt_(const BufferStoreNode *op) final {
    const auto dst_scope =
        runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data));
    const bool writes_global = dst_scope.to_string() == "global";
    if (!writes_global) {
      has_simt_copy_ = true;
    }
    // Keep descending so nested expressions are visited as well.
    StmtExprVisitor::VisitStmt_(op);
  }

  // Becomes true once a store to a non-"global" buffer has been visited.
  bool has_simt_copy_{false};
};

83
84
85
86
87
/*!
 * \brief Inserts default set_max_nreg calls into warp-specialized code.
 *
 * For every AttrStmt keyed by attr::kWarpSpecializationScope whose body is an
 * IfThenElse, the then-branch is treated as the producer and the else-branch
 * as the consumer. When no explicit hints were collected (both collected
 * values are 0) and the producer contains no store to a non-"global" buffer,
 * a set_max_nreg(24, 0) call is prepended to the producer and a
 * set_max_nreg(240, 1) call is prepended to the consumer. Otherwise a
 * placeholder Evaluate(0) is prepended to each branch.
 */
class SetMaxNRegInjector : public StmtExprMutator {
public:
  /*!
   * \brief Run the injection on a function.
   * \param f The function to rewrite.
   * \return The function unchanged when SetMaxNRegCollector::Collect yields an
   *         empty array, otherwise the function with a rewritten body.
   */
  static PrimFunc Inject(PrimFunc f) {
    SetMaxNRegInjector injector;
    injector.nreg_ = SetMaxNRegCollector::Collect(f);
    if (injector.nreg_.empty()) {
      return f;
    }
    f.CopyOnWrite()->body = injector(f->body);
    return f;
  }

private:
  Stmt VisitStmt_(const EvaluateNode *op) final {
    if (const CallNode *call = op->value.as<CallNode>()) {
      if (call->op.same_as(no_set_max_nreg())) {
        // Drop the no_set_max_nreg marker by replacing it with a no-op.
        // Explicit set_max_nreg calls are not matched here and are kept.
        return Evaluate(0);
      }
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent &&
        Downcast<IterVar>(op->node)->thread_tag == "threadIdx.x") {
      thread_iv_ = Downcast<IterVar>(op->node);
      need_update_thread_extent_ = false;
      AttrStmt attr_stmt = Downcast<AttrStmt>(StmtExprMutator::VisitStmt_(op));
      // NOTE(review): nothing in this class sets need_update_thread_extent_ to
      // true, so this update path looks unreachable here - confirm before
      // removing.
      if (need_update_thread_extent_) {
        thread_iv_.CopyOnWrite()->dom = {0, updated_thread_extent_.value()};
        attr_stmt.CopyOnWrite()->node = thread_iv_;
        attr_stmt.CopyOnWrite()->value = updated_thread_extent_.value();
      }
      thread_iv_ = {};
      return attr_stmt;
    } else if (op->attr_key == attr::kWarpSpecializationScope) {
      // Probe the body type instead of Downcast-ing it: Downcast aborts on a
      // type mismatch, which would make the fallback below unreachable.
      const IfThenElseNode *branch = op->body.as<IfThenElseNode>();
      if (branch == nullptr) {
        return StmtExprMutator::VisitStmt_(op);
      }
      Stmt producer_body = branch->then_case;
      Optional<Stmt> consumer_body = branch->else_case;
      ICHECK(consumer_body.defined()) << "Consumer body is undefined";

      ICHECK_EQ(nreg_.size(), 2U)
          << "Expected one register hint per branch, got " << nreg_.size();
      // nreg_ already holds IntImm values; no extra cast is required.
      const auto dec_reg = nreg_[0]->value;
      const auto inc_reg = nreg_[1]->value;

      // Placeholders used when no register hint is injected.
      auto inc_reg_stmt = Evaluate(0);
      auto dec_reg_stmt = Evaluate(0);

      bool has_simt_copy = SimtCopyDetector::Detect(producer_body);

      // Inject the default hints only when no hint was collected (both values
      // are 0; {-1, -1} means injection was disabled) and the producer has no
      // store to a non-"global" buffer.
      if (dec_reg == 0 && inc_reg == 0 && !has_simt_copy) {
        auto inc_reg_num = IntImm(DataType::Int(32), 240);
        auto dec_reg_num = IntImm(DataType::Int(32), 24);
        inc_reg_stmt = Evaluate(
            Call(DataType::Handle(), set_max_nreg(), {inc_reg_num, 1}));
        dec_reg_stmt = Evaluate(
            Call(DataType::Handle(), set_max_nreg(), {dec_reg_num, 0}));
      }

      // Prepend the register statements to each branch.
      // NOTE(review): the branch bodies are reused as-is and are not visited
      // by this mutator - confirm that this is intended.
      Array<Stmt> producer_stmts;
      producer_stmts.push_back(dec_reg_stmt);
      producer_stmts.push_back(producer_body);
      auto new_producer_body = SeqStmt(producer_stmts);

      Array<Stmt> consumer_stmts;
      consumer_stmts.push_back(inc_reg_stmt);
      consumer_stmts.push_back(consumer_body.value());
      auto new_consumer_body = SeqStmt(consumer_stmts);

      auto new_if_stmt =
          IfThenElse(branch->condition, new_producer_body, new_consumer_body);
      auto new_attr = AttrStmt(op->node, op->attr_key, op->value, new_if_stmt);

      return new_attr;
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

  // Register hints from SetMaxNRegCollector: [0] decrease, [1] increase.
  Array<IntImm> nreg_;
  // threadIdx.x IterVar of the thread_extent attribute currently being visited.
  IterVar thread_iv_;
  Optional<PrimExpr> updated_thread_extent_;
  bool need_update_thread_extent_ = false;
};

using namespace tir::transform;

/*!
 * \brief Create the "tl.AnnotateWarpGroupRegAlloc" PrimFunc pass.
 * \return A pass that applies SetMaxNRegInjector::Inject to each PrimFunc.
 */
tvm::transform::Pass AnnotateWarpGroupRegAlloc() {
  // The module and pass context are not needed by the injector.
  auto inject_reg_alloc = [](PrimFunc func, const IRModule & /*mod*/,
                             const PassContext & /*ctx*/) -> PrimFunc {
    return SetMaxNRegInjector::Inject(std::move(func));
  };
  return CreatePrimFuncPass(inject_reg_alloc, 0,
                            "tl.AnnotateWarpGroupRegAlloc", {});
}

184
TVM_FFI_STATIC_INIT_BLOCK() {
185
186
187
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.AnnotateWarpGroupRegAlloc",
                        AnnotateWarpGroupRegAlloc);
188
}
189
190
191

} // namespace tl
} // namespace tvm