operator.h 5.25 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
/*!
 * \file tl/op/op.h
 * \brief Tile library operations.
 *
 */

#ifndef TVM_TL_OP_OP_H_
#define TVM_TL_OP_OP_H_

#include <tvm/arith/analyzer.h>
#include <tvm/ir/op.h>
#include <tvm/target/target.h>
#include <tvm/tir/buffer.h>
14
#include <tvm/tir/op_attr_types.h>
15
#include <tvm/tir/stmt.h>
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

#include "../layout/layout.h"

namespace tvm {
namespace tl {

using namespace tir;

/// Maps TIR Vars to Buffers; used to resolve buffer arguments when parsing
/// tile operators (see ParseOperator / OpBuilderFunc).
using BufferMap = Map<Var, Buffer>;

/// Associates Buffers with their Layouts.
using LayoutMap = Map<Buffer, Layout>;

/// Callback used during lowering to request a temporary (workspace) buffer.
/// NOTE(review): the (int, DataType) arguments presumably give the size and
/// element type of the requested buffer — confirm against the implementation.
using AddWorkspaceCallback = std::function<PrimExpr(int, DataType)>;

/**
 * Strictness level passed to TileOperatorNode::InferLayout.
 *
 * Enumerators carry the explicit values 0 (kFree), 1 (kCommon), and 2
 * (kStrict). NOTE(review): the names suggest strictness increases with the
 * value — confirm against the InferLayout implementations.
 */
enum class InferLevel : int { kFree = 0, kCommon = 1, kStrict = 2 };

struct LowerArgs {
  Target target;
36
  Range thread_bounds;
37
38
39
40
  Var thread_var;
  AddWorkspaceCallback AddWorkspace;
  LayoutMap layout_map;
  Map<Buffer, Buffer> buffer_remap;
41
  Array<Var> buffer_var_gemm;
42
43
44
45
};

struct LayoutInferArgs {
  Target target;
46
  Range thread_bounds;
47
48
49
50
  LayoutMap layout_map;
  Map<Buffer, Buffer> buffer_remap;
};

// Forward declarations: TileOperatorNode::Clone() returns a TileOperator,
// which is only defined after the node class.
class TileOperatorNode;
class TileOperator;

/**
 * @class TileOperatorNode
 * Abstract base class for tile-level operators.
 *
 * Implementations must provide lowering to TIR, layout inference, and cloning.
 */

/**
 * @fn Stmt TileOperatorNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const
 * Lower this tile operator to a TIR statement.
 *
 * @param T Lowering context and utilities (target, thread bounds, thread
 * variable, layout mappings, buffer remapping, and the AddWorkspace callback
 * for requesting temporary buffers).
 * @param analyzer Arithmetic analyzer used during lowering.
 * @return A TIR Stmt representing the lowered operator.
 */

/**
 * @fn LayoutMap TileOperatorNode::InferLayout(const LayoutInferArgs &T, InferLevel level) const
 * Infer buffer layouts for this operator.
 *
 * The returned LayoutMap associates input/output Buffers with inferred Layouts.
 * The `level` controls how strictly layouts are determined (kFree, kCommon,
 * kStrict).
 *
 * @param T Layout inference context (target, thread bounds, existing
 * layout_map, buffer_remap).
 * @param level Inference strictness level.
 * @return A LayoutMap mapping Buffers to their inferred Layouts.
 */

/**
 * @fn TileOperator TileOperatorNode::Clone() const
 * Create a deep copy of this tile operator.
 *
 * @return A TileOperator referencing the cloned operator instance.
 */

/**
 * @class TileOperator
 * Reference wrapper (ObjectRef) for TileOperatorNode.
 *
 * Use this ObjectRef to hold and pass tile operator instances within the
 * runtime.
 */

/**
 * @fn Var GetVarFromAccessPtr(const PrimExpr &expr)
 * Extract the underlying Var from an access-pointer expression.
 *
 * If `expr` is an access pointer that directly refers to a variable, that Var
 * is returned; otherwise a null (default-constructed) Var is returned.
 *
 * @param expr The pointer/access expression to inspect.
 * @return The extracted Var, or a null Var if none can be found.
 */

/**
 * @fn TileOperator ParseOperator(Call call, BufferMap vmap)
 * Parse a Call into a TileOperator using the provided buffer mapping.
 *
 * @param call The Call node representing a tile operator invocation.
 * @param vmap Mapping from TIR Vars to Buffers for resolving buffer arguments.
 * @return A TileOperator constructed from the call and buffer map.
 */

/**
 * @fn TileOperator ParseOperator(Stmt stmt, BufferMap vmap)
 * Parse a Stmt into a TileOperator using the provided buffer mapping.
 *
 * @param stmt The Stmt representing a tile operator region or call.
 * @param vmap Mapping from TIR Vars to Buffers for resolving buffer references.
 * @return A TileOperator constructed from the statement and buffer map.
 */

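/**
 * @typedef OpBuilderFunc
 * Function type for TL operator builders exposed to the FFI.
 *
 * Builder functions take an array of PrimExpr arguments and a BufferMap, and
 * return a constructed TileOperator. TIR_REGISTER_TL_OP stores one under each
 * operator's "TLOpBuilder" attribute.
 */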

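/**
 * @def TIR_REGISTER_TL_OP(Entry, OpName)
 * Register a TL operator and its builder with TVM's op registry.
 *
 * `Entry` must declare `static const Op &Get();` (this macro supplies the
 * definition, returning the registered "tl.OpName" Op) and a constructor
 * taking `(Array<PrimExpr>, BufferMap)`; the constructed `Entry` must be
 * convertible to TileOperator. The macro registers the operator under the
 * name "tl.OpName", sets its "TScriptPrinterName" attribute to "OpName", and
 * sets a "TLOpBuilder" attribute (an OpBuilderFunc) that constructs
 * `Entry(args, vmap)`. The expansion has no trailing semicolon, so the use
 * site must supply one.
 *
 * Usage: TIR_REGISTER_TL_OP(MyOpEntry, MyOp);
 */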
class TileOperatorNode : public Object {
public:
141
  virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const = 0;
142

143
  virtual LayoutMap InferLayout(const LayoutInferArgs &T,
144
                                InferLevel level) const = 0;
145

146
147
  virtual TileOperator Clone() const = 0;

148
  static constexpr const char *_type_key = "tl.TileOperator";
149
150
151

  TVM_DECLARE_BASE_OBJECT_INFO(TileOperatorNode, Object);
};
152

153
class TileOperator : public ObjectRef {
154
155
public:
  TVM_DEFINE_OBJECT_REF_METHODS(TileOperator, ObjectRef, TileOperatorNode);
156
157
};

158
Var GetVarFromAccessPtr(const PrimExpr &expr);
159

160
161
162
TileOperator ParseOperator(Call call, BufferMap vmap);
TileOperator ParseOperator(Stmt stmt, BufferMap vmap);

163
164
using OpBuilderFunc =
    ffi::TypedFunction<TileOperator(Array<PrimExpr>, BufferMap)>;
165
166
167
168
169
170
171
172
173

// Defines Entry::Get() (returning the registered "tl.<OpName>" Op) and
// registers that op with a "TScriptPrinterName" attribute of "<OpName>" and a
// "TLOpBuilder" attribute (OpBuilderFunc) that constructs Entry(args, vmap).
// Every line of the definition must end in a backslash; the expansion has no
// trailing semicolon, so the use site supplies it.
#define TIR_REGISTER_TL_OP(Entry, OpName)                                      \
  const Op &Entry::Get() {                                                     \
    static const Op &op = Op::Get("tl." #OpName);                              \
    return op;                                                                 \
  }                                                                            \
  TVM_REGISTER_OP("tl." #OpName)                                               \
      .set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName)             \
      .set_attr<OpBuilderFunc>("TLOpBuilder",                                  \
                               [](Array<PrimExpr> args, BufferMap vmap) {      \
                                 return Entry(args, vmap);                     \
                               })

} // namespace tl
} // namespace tvm
180

181
#endif // TVM_TL_OP_OP_H_