"vscode:/vscode.git/clone" did not exist on "c7af5b73feffef1422c6c138280631b04745511e"
cluster_planning.cc 4.07 KB
Newer Older
1
2
3
4
5
6
/*!
 * \file clasuter_planning.cc
 * \brief Plan the cluster for GPU(sm90+) blocks
 */

#include <cstring>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../support/ffi_aliases.h"

namespace tvm {
namespace tir {

class ClusterPlanner {
19
20
public:
  static PrimFunc Substitute(PrimFunc &f) {
21
22
    // Step 1: Collect the read region of the function
    Map<Var, Buffer> buffer_data_to_buffer_;
23
    for (const auto &[_, buffer] : f->buffer_map) {
24
25
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
26
27
28
29
    Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
                /*body*/ f->body);
    Array<Array<BufferRegion>> access =
        GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
30
31
32
33
34
35
36
    auto reads = access[0];

    BlockIdxVisitor blockIdx_visitor;
    blockIdx_visitor(f->body);
    auto dom_map = blockIdx_visitor.dom_map_;

    // Step 2: Collect mem reuse count for clustering on each dimension.
37
38
39
    std::unordered_map<const IterVarNode *, size_t> mem_reuse_count;
    for (auto iv : dom_map)
      mem_reuse_count[iv] = 0;
40

41
    for (const auto &buffer_region : reads) {
42
43
      PrimExpr size = buffer_region->buffer->dtype.bits();
      RegionVisitor visitor;
44
      for (const auto &range : buffer_region->region) {
45
46
47
48
49
50
        size = size * range->extent;
        visitor(range->min);
      }
      size = arith::Analyzer().Simplify(size);
      if (auto imm = size.as<IntImmNode>()) {
        for (auto iv : dom_map) {
51
52
          if (visitor.seen_.count(iv->var.get()) == 0)
            mem_reuse_count[iv] += imm->value;
53
54
55
56
57
58
59
60
61
        }
      }
    }

    // Step 3: Pick the cluster dimension with the largest mem_reuse.
    size_t mem_reuse_max = 0;
    String cluster_tag;
    for (auto iv : dom_map) {
      if (auto extent = iv->dom->extent.as<IntImmNode>()) {
62
63
        if (extent->value % cluster_size_ == 0 &&
            mem_reuse_count[iv] > mem_reuse_max) {
64
65
66
67
68
69
70
          cluster_tag = iv->thread_tag;
          mem_reuse_max = mem_reuse_count[iv];
        }
      }
    }

    if (mem_reuse_max > 0) {
71
72
      std::string tag_str =
          static_cast<std::string>(cluster_tag); // Convert to std::string
73
74
75
76
77
78
79
      if (tag_str.rfind("blockIdx", 0) == 0) {
        // starts with "blockIdx"
        tag_str = "clusterIdx" + tag_str.substr(strlen("blockIdx"));
      } else {
        // Unexpected format — maybe just prefix
        tag_str = "clusterIdx" + tag_str;
      }
80
      cluster_tag = String(tag_str); // Convert back
81
82
83
84
85
86
      return WithAttr(f, cluster_tag, Integer(cluster_size_));
    } else {
      return f;
    }
  }

87
private:
88
89
90
  ClusterPlanner() = default;

  class RegionVisitor : public ExprVisitor {
91
  public:
92
    RegionVisitor() {};
93
94
    void VisitExpr_(const VarNode *var) { seen_.insert(var); }
    std::unordered_set<const VarNode *> seen_;
95
96
97
  };

  class BlockIdxVisitor : public StmtVisitor {
98
  public:
99
    BlockIdxVisitor() {};
100
    void VisitStmt_(const AttrStmtNode *attr) final {
101
102
103
104
105
106
107
108
109
      if (attr->attr_key == attr::thread_extent) {
        IterVar iv = Downcast<IterVar>(attr->node);
        String tag = iv->thread_tag;
        if (tag == "blockIdx.x" || tag == "blockIdx.y" || tag == "blockIdx.z")
          dom_map_.insert(iv.get());
      }
      StmtVisitor::VisitStmt_(attr);
    }
    /*! \brief The map from vars to blockidx extents. */
110
    std::unordered_set<const IterVarNode *> dom_map_;
111
112
113
114
115
116
117
118
119
120
121
  };

  /*! \brief Currently set the plossible cluster size as 2 */
  const static int cluster_size_ = 2;
};

/*!
 * \brief Run cluster planning on a single PrimFunc.
 * \param f The function to plan.
 * \return The result of ClusterPlanner::Substitute on \p f.
 */
PrimFunc ClusterPlanning(PrimFunc f) {
  PrimFunc planned = ClusterPlanner::Substitute(f);
  return planned;
}

namespace transform {

tvm::transform::Pass ClusterPlanning() {
122
  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
123
124
125
126
127
    return ClusterPlanning(std::move(f));
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {});
}

128
TVM_FFI_STATIC_INIT_BLOCK() {
129
130
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.ClusterPlanning", ClusterPlanning);
131
}
132
} // namespace transform
} // namespace tir
} // namespace tvm