annotate_read_only_params.cc 6.09 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
/*!
 * \file annotate_read_only_params.cc
 * \brief Annotate PrimFunc parameters that are read-only (never written).
 */

#include <string>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_set>

namespace tvm {
namespace tl {
using namespace tir;
using namespace ffi;

/*!
 * \brief A simple visitor that marks handle parameters as written when they
 *        appear on the LHS of a BufferStore or in a tvm_access_ptr with write
 * flag.
 */
class ReadWriteMarker : public StmtExprVisitor {
public:
  explicit ReadWriteMarker(
      const std::unordered_set<const VarNode *> &param_or_data_vars)
      : param_or_data_vars_(param_or_data_vars) {}

  const std::unordered_set<const VarNode *> &written() const {
    return written_;
  }

  // Try to resolve the underlying buffer data Var from a pointer-like
  // argument. Supports:
  //  - address_of(BufferLoad(...)) -> returns buffer->data
  //  - BufferLoad(...)             -> returns buffer->data
  // Otherwise returns nullptr.
  const VarNode *ResolveDataVarFromPtrArg(const PrimExpr &arg) const {
    if (const auto *call = arg.as<CallNode>()) {
      if (call->op.same_as(builtin::address_of())) {
        if (call->args.size() == 1U) {
          if (const auto *load = call->args[0].as<BufferLoadNode>()) {
            return load->buffer->data.get();
          }
        }
      }
    } else if (const auto *load = arg.as<BufferLoadNode>()) {
      return load->buffer->data.get();
    }
    return nullptr;
  }

  void VisitStmt_(const BufferStoreNode *op) final {
    const VarNode *data = op->buffer->data.get();
    if (param_or_data_vars_.count(data)) {
      written_.insert(data);
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitExpr_(const CallNode *op) final {
    // Detect tvm_access_ptr writes. Be conservative if rw_mask is non-constant.
    if (op->op.same_as(builtin::tvm_access_ptr())) {
      if (op->args.size() == 5U) {
        if (const VarNode *buf = op->args[1].as<VarNode>()) {
          const IntImmNode *flag = op->args[4].as<IntImmNode>();
          bool maybe_write = true; // default conservative
          if (flag) {
            maybe_write = (flag->value & 2) != 0; // write bit set
          }
          if (maybe_write && param_or_data_vars_.count(buf)) {
            written_.insert(buf);
          }
        }
      }
    } else {
      // Generic fallback: mark buffers that appear as
      // address_of(BufferLoad(...)) in call arguments as written. This matches
      // patterns like
      //   tl.tma_store(address_of(smem[..]), address_of(gmem[..]), ...)
      //   call_extern("AtomicAdd*", address_of(gmem[..]), ...)
      // and avoids over-marking plain BufferLoad used for reads.
      for (const PrimExpr &a : op->args) {
        if (const auto *c = a.as<CallNode>()) {
          if (c->op.same_as(builtin::address_of()) && c->args.size() == 1U) {
            if (const auto *bl = c->args[0].as<BufferLoadNode>()) {
              const VarNode *data = bl->buffer->data.get();
              if (param_or_data_vars_.count(data)) {
                written_.insert(data);
              }
            }
          }
        }
      }
    }
    StmtExprVisitor::VisitExpr_(op);
  }

private:
  std::unordered_set<const VarNode *> param_or_data_vars_;
  std::unordered_set<const VarNode *> written_;
};

/*!
 * \brief Annotate PrimFunc with indices of read-only handle parameters.
 *
 * Adds an Array<Integer> attribute "tl.readonly_param_indices" that lists
 * parameter indices which correspond to handle parameters that are never
 * written inside the function body. This can be used by codegen to emit
 * `const` qualifiers to enable read-only caching (e.g., __ldg on CUDA).
 *
 * A parameter counts as written when ReadWriteMarker records either the
 * parameter var itself or the data var of the buffer bound to it in
 * buffer_map.
 *
 * \param f The function to analyze.
 * \return `f` unchanged when it has no handle parameters or when no handle
 *         parameter is read-only; otherwise `f` with the attribute attached.
 */
static tir::PrimFunc MarkReadOnlyParams(tir::PrimFunc f) {
  // Gather handle params and their corresponding buffer data vars (aliases).
  std::unordered_set<const VarNode *> param_or_data_vars;
  for (const Var &param : f->params) {
    if (!param->dtype.is_handle())
      continue;
    param_or_data_vars.insert(param.get());
    // If there is a buffer_map entry for this param, include its data var too.
    if (auto opt = f->buffer_map.Get(param)) {
      param_or_data_vars.insert(opt.value()->data.get());
    }
  }
  if (param_or_data_vars.empty())
    return f;

  ReadWriteMarker marker(param_or_data_vars);
  marker(f->body);
  const std::unordered_set<const VarNode *> &written = marker.written();

  // Determine read-only parameter indices among all params (handle only).
  Array<Integer> readonly_indices;
  for (size_t i = 0; i < f->params.size(); ++i) {
    const Var &param = f->params[i];
    if (!param->dtype.is_handle())
      continue;

    // Written either directly through the param var ...
    bool is_written = written.count(param.get()) != 0;
    if (!is_written) {
      // ... or through the data var of its bound buffer.
      if (auto opt = f->buffer_map.Get(param)) {
        is_written = written.count(opt.value()->data.get()) != 0;
      }
    }

    if (!is_written) {
      readonly_indices.push_back(Integer(static_cast<int>(i)));
    }
  }

  // Only attach the attribute when there is something to report.
  if (!readonly_indices.empty()) {
    Map<String, Any> attrs;
    attrs.Set(String("tl.readonly_param_indices"), readonly_indices);
    f = WithAttrs(std::move(f), attrs);
  }
  return f;
}

namespace transform {
using namespace tir::transform;

/*!
 * \brief Create the "tl.AnnotateReadOnlyParams" PrimFunc pass.
 *
 * The pass applies MarkReadOnlyParams to each PrimFunc; the module and the
 * pass context handed to the callback are not consulted.
 */
Pass AnnotateReadOnlyParams() {
  return CreatePrimFuncPass(
      [](PrimFunc func, const IRModule & /*mod*/,
         const tvm::transform::PassContext & /*ctx*/) {
        return MarkReadOnlyParams(std::move(func));
      },
      0, "tl.AnnotateReadOnlyParams", {});
}

// Expose the pass factory through the FFI global registry under the name
// "tl.transform.AnnotateReadOnlyParams".
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def(
      "tl.transform.AnnotateReadOnlyParams", AnnotateReadOnlyParams);
}

} // namespace transform
} // namespace tl
} // namespace tvm