"src/vscode:/vscode.git/clone" did not exist on "ea66515560918c118adac84e22215b431bd28e84"
cluster_planning.cc 4 KB
Newer Older
1
2
3
4
5
6
/*!
 * \file cluster_planning.cc
 * \brief Plan the thread-block cluster for GPU (sm90+) kernels.
 */

#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <cstddef>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

namespace tvm {
namespace tir {

class ClusterPlanner {
17
18
public:
  static PrimFunc Substitute(PrimFunc &f) {
19
20
    // Step 1: Collect the read region of the function
    Map<Var, Buffer> buffer_data_to_buffer_;
21
    for (const auto &[_, buffer] : f->buffer_map) {
22
23
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
24
25
26
27
    Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
                /*body*/ f->body);
    Array<Array<BufferRegion>> access =
        GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
28
29
30
31
32
33
34
    auto reads = access[0];

    BlockIdxVisitor blockIdx_visitor;
    blockIdx_visitor(f->body);
    auto dom_map = blockIdx_visitor.dom_map_;

    // Step 2: Collect mem reuse count for clustering on each dimension.
35
36
37
    std::unordered_map<const IterVarNode *, size_t> mem_reuse_count;
    for (auto iv : dom_map)
      mem_reuse_count[iv] = 0;
38

39
    for (const auto &buffer_region : reads) {
40
41
      PrimExpr size = buffer_region->buffer->dtype.bits();
      RegionVisitor visitor;
42
      for (const auto &range : buffer_region->region) {
43
44
45
46
47
48
        size = size * range->extent;
        visitor(range->min);
      }
      size = arith::Analyzer().Simplify(size);
      if (auto imm = size.as<IntImmNode>()) {
        for (auto iv : dom_map) {
49
50
          if (visitor.seen_.count(iv->var.get()) == 0)
            mem_reuse_count[iv] += imm->value;
51
52
53
54
55
56
57
58
59
        }
      }
    }

    // Step 3: Pick the cluster dimension with the largest mem_reuse.
    size_t mem_reuse_max = 0;
    String cluster_tag;
    for (auto iv : dom_map) {
      if (auto extent = iv->dom->extent.as<IntImmNode>()) {
60
61
        if (extent->value % cluster_size_ == 0 &&
            mem_reuse_count[iv] > mem_reuse_max) {
62
63
64
65
66
67
68
          cluster_tag = iv->thread_tag;
          mem_reuse_max = mem_reuse_count[iv];
        }
      }
    }

    if (mem_reuse_max > 0) {
69
70
71
72
73
74
75
76
77
      std::string tag_str = cluster_tag; // Convert to std::string
      if (tag_str.rfind("blockIdx", 0) == 0) {
        // starts with "blockIdx"
        tag_str = "clusterIdx" + tag_str.substr(strlen("blockIdx"));
      } else {
        // Unexpected format — maybe just prefix
        tag_str = "clusterIdx" + tag_str;
      }
      cluster_tag = tvm::ffi::String(tag_str); // Convert back
78
79
80
81
82
83
      return WithAttr(f, cluster_tag, Integer(cluster_size_));
    } else {
      return f;
    }
  }

84
private:
85
86
87
  ClusterPlanner() = default;

  class RegionVisitor : public ExprVisitor {
88
  public:
89
    RegionVisitor(){};
90
91
    void VisitExpr_(const VarNode *var) { seen_.insert(var); }
    std::unordered_set<const VarNode *> seen_;
92
93
94
  };

  class BlockIdxVisitor : public StmtVisitor {
95
  public:
96
    BlockIdxVisitor(){};
97
    void VisitStmt_(const AttrStmtNode *attr) final {
98
99
100
101
102
103
104
105
106
      if (attr->attr_key == attr::thread_extent) {
        IterVar iv = Downcast<IterVar>(attr->node);
        String tag = iv->thread_tag;
        if (tag == "blockIdx.x" || tag == "blockIdx.y" || tag == "blockIdx.z")
          dom_map_.insert(iv.get());
      }
      StmtVisitor::VisitStmt_(attr);
    }
    /*! \brief The map from vars to blockidx extents. */
107
    std::unordered_set<const IterVarNode *> dom_map_;
108
109
110
111
112
113
114
115
116
117
118
  };

  /*! \brief Currently set the plossible cluster size as 2 */
  const static int cluster_size_ = 2;
};

/*!
 * \brief Run cluster planning on a single PrimFunc.
 * \param f The function to analyze.
 * \return The result of ClusterPlanner::Substitute on \p f.
 */
PrimFunc ClusterPlanning(PrimFunc f) {
  PrimFunc planned = ClusterPlanner::Substitute(f);
  return planned;
}

namespace transform {

/*!
 * \brief Create the "tl.ClusterPlanning" PrimFunc pass.
 *
 * Each PrimFunc is handed to tir::ClusterPlanning; the module and pass
 * context are not consulted.
 * \return A PrimFunc pass registered with opt level 0.
 */
tvm::transform::Pass ClusterPlanning() {
  auto plan_one_func = [](PrimFunc func, const IRModule & /*mod*/,
                          const PassContext & /*ctx*/) {
    // Qualified so the call does not depend on ADL to get past this
    // zero-argument overload of the same name.
    return tir::ClusterPlanning(std::move(func));
  };
  return CreatePrimFuncPass(plan_one_func, 0, "tl.ClusterPlanning", {});
}

125
126
127
128
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.ClusterPlanning", ClusterPlanning);
});
129
} // namespace transform
} // namespace tir
} // namespace tvm