/*!
 * \file collector.h
 * \brief Collect information from the IR
 */

#include "arith/ir_visitor_with_analyzer.h"
#include "tir/analysis/var_use_def_analysis.h"
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../../op/builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

class ThreadTagChecker : public StmtExprVisitor {
public:
  static bool HasOnlyThreadIdxX(const PrimFunc &f) {
    ThreadTagChecker checker;
    checker(f->body);
    return checker.is_valid_;
  }

29
30
31
32
33
34
  static IterVar GetThreadVar(const Stmt &body) {
    ThreadTagChecker checker;
    checker(body);
    return checker.thread_var_;
  }

35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
private:
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iter_var = Downcast<IterVar>(op->node);
      String thread_tag = iter_var->thread_tag;
      bool is_y_or_z =
          thread_tag == "threadIdx.y" || thread_tag == "threadIdx.z";

      if (!thread_tag.empty() && is_y_or_z && !is_one(iter_var->dom->extent)) {
        is_valid_ = false;
      }
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const ForNode *op) final {
    if (op->kind == ForKind::kThreadBinding) {
      ICHECK(op->thread_binding.defined());
      String thread_tag = op->thread_binding.value()->thread_tag;
54
55
56
      if (thread_tag == "threadIdx.x") {
        thread_var_ = Downcast<IterVar>(op->thread_binding);
      }
57
58
59
60
61
62
63
64
65
66
67
68
      bool is_y_or_z =
          thread_tag == "threadIdx.y" || thread_tag == "threadIdx.z";
      if (!thread_tag.empty() && is_y_or_z) {
        auto iter_var = Downcast<IterVar>(op->thread_binding);
        if (iter_var.defined() && iter_var->dom.defined() &&
            !is_one(iter_var->dom->extent)) {
          is_valid_ = false;
        }
      }
    }
    StmtExprVisitor::VisitStmt_(op);
  }
69
  IterVar thread_var_;
70
71
72
73
74
  bool is_valid_ = true;
};

} // namespace tl
} // namespace tvm