// hoist_nonrestrict_params.cc
/*
 * Hoist tl.non_restrict_params block annotation(s) to a PrimFunc attribute.
 *
 * The pass recursively scans every block in the function body (not only the
 * root block) and unions the tl.non_restrict_params entries it finds. It then
 * merges them with any existing PrimFunc-level attribute and writes the
 * deduplicated result back to the PrimFunc attrs. Frontends may therefore
 * place the annotation on any block within the function body.
 */
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../op/builtin.h"

namespace tvm {
namespace tl {
using namespace tvm::tir;

/*!
 * \brief Gathers the Vars listed in the attr::kNonRestrictParams annotation of
 *        every Block reachable from a statement.
 *
 * Vars are deduplicated in two ways:
 *  - by object identity;
 *  - by a normalized name hint, in which one trailing "_handle" suffix is
 *    stripped.
 *
 * So `A_handle` and `A`, or a recreated Var with the same name, are recorded
 * only once. The first occurrence wins, and Result() preserves first-seen
 * order.
 */
class NonRestrictCollector : public StmtVisitor {
public:
  /*! \brief Recursively visit `stmt` and record every annotated Var. */
  void Collect(const Stmt &stmt) { VisitStmt(stmt); }

  /*!
   * \brief Record a single Var using exactly the same dedup rules as Collect().
   *
   * Callers can fold in Vars from other sources, such as an existing PrimFunc
   * attribute, without re-implementing (and diverging from) the dedup logic.
   */
  void Add(const Var &v) { MaybeInsert(v); }

  /*! \brief The deduplicated Vars, in first-seen order. */
  Array<Var> Result() const {
    Array<Var> out;
    out.reserve(collected_.size());
    for (const Var &v : collected_)
      out.push_back(v);
    return out;
  }

private:
  /*!
   * \brief Strip one trailing "_handle" from `s`, provided at least one
   *        character precedes it; otherwise return `s` unchanged.
   */
  static std::string NormalizeName(const std::string &s) {
    static constexpr char kSuffix[] = "_handle";
    constexpr std::size_t kSuffixLen = sizeof(kSuffix) - 1;
    if (s.size() > kSuffixLen &&
        s.compare(s.size() - kSuffixLen, kSuffixLen, kSuffix) == 0) {
      return s.substr(0, s.size() - kSuffixLen);
    }
    return s;
  }

  /*!
   * \brief Append `v` unless it is undefined, already seen, or shares a
   *        normalized name with a Var already recorded.
   */
  void MaybeInsert(const Var &v) {
    if (!v.defined())
      return;
    const VarNode *p = v.get();
    if (seen_ptr_.count(p))
      return;
    // Also dedup by normalized name to be robust w.r.t. recreated Vars.
    std::string norm = NormalizeName(v->name_hint);
    if (seen_name_.count(norm))
      return;
    seen_ptr_.insert(p);
    seen_name_.insert(std::move(norm));
    collected_.push_back(v);
  }

  void VisitStmt_(const BlockNode *op) final {
    auto it = op->annotations.find(attr::kNonRestrictParams);
    // Only array-valued annotations are considered; other values are ignored.
    // NOTE(review): Downcast presumably rejects arrays holding non-Var
    // elements; behavior kept as-is.
    if (it != op->annotations.end() && (*it).second.as<ffi::ArrayObj>()) {
      Array<Var> vars = tvm::Downcast<Array<Var>>((*it).second);
      for (const Var &v : vars) {
        MaybeInsert(v);
      }
    }
    // Recurse into child statements so nested blocks are scanned too.
    StmtVisitor::VisitStmt_(op);
  }

  std::vector<Var> collected_;                   // Output, first-seen order.
  std::unordered_set<const VarNode *> seen_ptr_; // Identity dedup.
  std::unordered_set<std::string> seen_name_;    // Normalized-name dedup.
};

/*!
 * \brief Union every block-level attr::kNonRestrictParams annotation in `f`
 *        with any existing PrimFunc-level attribute of the same key.
 *
 * Block annotations are taken first, in visit order, followed by the
 * pre-existing function attribute. Both pass through the collector's dedup
 * rules. An existing `A` and a block-level `A_handle` therefore collapse to
 * one entry, and undefined Vars are skipped.
 *
 * \param f The function to process; returned unchanged if undefined or if no
 *          entries are found.
 * \return `f` with attr::kNonRestrictParams set to the deduplicated union.
 */
static PrimFunc HoistNonRestrictParams(PrimFunc f) {
  if (!f.defined())
    return f;

  NonRestrictCollector collector;
  collector.Collect(f->body);

  // Merge the existing PrimFunc-level attribute through the same dedup logic
  // rather than a separate (and previously inconsistent) raw-name comparison.
  if (auto opt_existing = f->GetAttr<Array<Var>>(attr::kNonRestrictParams)) {
    for (const Var &v : opt_existing.value()) {
      collector.Add(v);
    }
  }

  Array<Var> merged = collector.Result();
  if (merged.empty())
    return f;

  return WithAttr(std::move(f), attr::kNonRestrictParams, std::move(merged));
}

namespace transform {

/*!
 * \brief Build the PrimFunc-level pass "tl.HoistNonRestrictParams".
 *
 * Each PrimFunc in the module is handed to tvm::tl::HoistNonRestrictParams.
 * The IRModule and PassContext arguments are not consulted.
 *
 * \return A PrimFunc pass registered at opt level 0 with no required passes.
 */
tvm::transform::Pass HoistNonRestrictParams() {
  constexpr int kOptLevel = 0;
  auto hoist_one = [](PrimFunc func, const IRModule & /*mod*/,
                      const tvm::transform::PassContext & /*ctx*/)
      -> PrimFunc { return tvm::tl::HoistNonRestrictParams(std::move(func)); };
  return tvm::tir::transform::CreatePrimFuncPass(
      hoist_one, kOptLevel, "tl.HoistNonRestrictParams", {});
}

} // namespace transform

} // namespace tl
} // namespace tvm

// Register the pass factory as the global FFI function
// "tl.transform.HoistNonRestrictParams".
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def(
      "tl.transform.HoistNonRestrictParams",
      tvm::tl::transform::HoistNonRestrictParams);
}