/*!
 * \file tl/op/reduce.h
 * \brief Define reduce operator.
 *
 */

#ifndef TVM_TL_OP_REDUCE_H_
#define TVM_TL_OP_REDUCE_H_

#include "operator.h"

namespace tvm {
/**
 * Tile operator node that performs a reduction (sum, max, min, etc.) along a
 * single tensor dimension.
 *
 * Represents a per-instance reduce operator with explicit source/destination
 * buffers, target dimension, reduction type, and a flag controlling whether the
 * destination is cleared before reduction.
 */

/**
 * Lower this ReduceOpNode into a TIR `Stmt` suitable for code generation.
 *
 * Produces the TIR statement(s) that implement the configured reduction.
 *
 * @return A TIR `Stmt` implementing the reduce operation.
 */

/**
 * Infer input/output layouts for this reduce operator.
 *
 * Returns a LayoutMap describing how input and output buffer layouts relate
 * for the configured reduction dimension.
 *
 * @param level Inference detail level that may affect how aggressively layouts
 * are inferred.
 * @return A LayoutMap mapping operator arguments to inferred layouts.
 */

/**
 * Retrieve the global operator descriptor for the reduce operator.
 *
 * @return A reference to the Op descriptor corresponding to this operator type.
 */

/**
 * Create a copy of this reduce operator as a TileOperator handle.
 *
 * The returned TileOperator preserves the node's configuration (buffers, dim,
 * type, clear).
 *
 * @return A TileOperator wrapping a cloned ReduceOpNode.
 */

/**
 * Construct the initial value used by the reduction (e.g., 0 for sum, -inf for
 * max).
 *
 * @return A PrimExpr representing the reduction's identity/init value.
 */

/**
 * Combine two partial values according to the configured reduction.
 *
 * Implements the binary reducer (for example, `a + b` for sum or `max(a, b)`
 * for max).
 *
 * @return A PrimExpr representing the reduced result of `a` and `b`.
 */

/**
 * Generate a string snippet suitable for code generation of the reducer
 * expression.
 *
 * The returned code fragment should implement the binary reduction operation in
 * the target backend's code string form.
 *
 * @return A std::string containing the codegen expression for the reducer.
 */

/**
 * Reference wrapper for ReduceOpNode as a TileOperator.
 *
 * Construct a ReduceOp from explicit arguments and a buffer map.
 */

/**
 * Construct a ReduceOp TileOperator from operator arguments and a buffer
 * mapping.
 *
 * @param args Operator arguments (typically shapes, axes, or other prim exprs).
 * @param vmap Mapping from argument names to tir::Buffer instances used by the
 * operator.
 */

/**
 * Tile operator node that computes a cumulative sum along a single tensor
 * dimension.
 *
 * Contains source/destination buffers, the target dimension, and a flag to
 * compute the cumulative sum in reverse order.
 */

/**
 * Lower this CumSumOpNode into a TIR `Stmt` suitable for code generation.
 *
 * Produces the TIR statement(s) that implement the configured cumulative-sum.
 *
 * @return A TIR `Stmt` implementing the cum-sum operation.
 */

/**
 * Infer input/output layouts for this cumulative-sum operator.
 *
 * Returns a LayoutMap describing how input and output buffer layouts relate
 * for the configured cumulative-sum dimension.
 *
 * @param level Inference detail level that may affect how aggressively layouts
 * are inferred.
 * @return A LayoutMap mapping operator arguments to inferred layouts.
 */

/**
 * Retrieve the global operator descriptor for the cumulative-sum operator.
 *
 * @return A reference to the Op descriptor corresponding to this operator type.
 */

/**
 * Create a copy of this cum-sum operator as a TileOperator handle.
 *
 * The returned TileOperator preserves the node's configuration (buffers, dim,
 * reverse).
 *
 * @return A TileOperator wrapping a cloned CumSumOpNode.
 */

/**
 * Reference wrapper for CumSumOpNode as a TileOperator.
 *
 * Construct a CumSumOp from explicit arguments and a buffer map.
 */

/**
 * Construct a CumSumOp TileOperator from operator arguments and a buffer
 * mapping.
 *
 * @param args Operator arguments (typically shapes, axes, or other prim exprs).
 * @param vmap Mapping from argument names to tir::Buffer instances used by the
 * operator.
 */
namespace tl {

using namespace tir;

/// Kind of reduction performed by ReduceOpNode.
///
/// The enumerator order (and therefore each underlying value) is kept stable
/// because other translation units may switch on or persist these values.
enum class ReduceType : uint8_t {
  kSum,    ///< Sum of the elements.
  kAbsSum, ///< Sum of absolute values (inferred from the name; confirm in reduce.cc).
  kMax,    ///< Maximum element.
  kMin,    ///< Minimum element.
  kAbsMax, ///< Maximum absolute value (inferred from the name; confirm in reduce.cc).
};

class ReduceOpNode : public TileOperatorNode {
public:
167
168
  tir::Buffer src, dst;
  int dim;
169
  ReduceType type;
170
171
  bool clear;

172
173
174
175
176
177
178
179
180
181
  static constexpr const char *_type_key = "tl.ReduceOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(ReduceOpNode, TileOperatorNode);

  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;
  static const Op &Get();
  TileOperator Clone() const;

private:
182
  PrimExpr MakeInitValue() const;
183
  PrimExpr MakeReduce(const PrimExpr &a, const PrimExpr &b) const;
184
185
186
  std::string MakeCodegenReducer() const;
};

187
class ReduceOp : public TileOperator {
188
public:
189
190
  TVM_DEFINE_OBJECT_REF_METHODS(ReduceOp, TileOperator, ReduceOpNode);
  TVM_DLL ReduceOp(Array<PrimExpr> args, BufferMap vmap);
191
  static const Op &Get();
192
};
193

194
195
class CumSumOpNode : public TileOperatorNode {
public:
196
197
198
  tir::Buffer src, dst;
  int dim;
  bool reverse;
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
  static constexpr const char *_type_key = "tl.CumSumOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(CumSumOpNode, TileOperatorNode);

  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;
  static const Op &Get();
  TileOperator Clone() const;
};

/// Reference (handle) type wrapping a CumSumOpNode as a TileOperator.
class CumSumOp : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(CumSumOp, TileOperator, CumSumOpNode);

  /// Build a CumSumOp from the operator call arguments and its buffer map.
  TVM_DLL CumSumOp(Array<PrimExpr> args, BufferMap vmap);

  /// Global Op descriptor registered for the cumulative-sum operator.
  static const Op &Get();
};

} // namespace tl
} // namespace tvm

#endif //  TVM_TL_OP_REDUCE_H_