/*!
 * \file tl/op/copy.h
 * \brief Copy operations and Tensor Memory Access (TMA) descriptors
 */

#ifndef TVM_TL_OP_COPY_H_
#define TVM_TL_OP_COPY_H_

#include <cstddef>
#include <cstdint>

#include "operator.h"
#include "parallel.h"

namespace tvm {
namespace tl {
using namespace tir;

/// Copy instruction types for different memory access patterns.
enum class CopyInst : uint8_t {
  kNormal = 0,    ///< utilize ldg/stg or cpasync or any buffer copy
  kLDSM = 1,      ///< ldmatrix memory copy
  kSTSM = 2,      ///< stmatrix memory copy
  kBulkLoad = 3,  ///< utilize tma load
  kBulkStore = 4, ///< utilize tma store
  // The 1D bulk load/store variants are kept separate from the
  // multi-dimensional ones because they have different memory access
  // patterns.
  kBulkLoad1D = 5,  ///< utilize tma load 1d
  kBulkStore1D = 6, ///< utilize tma store 1d
  kTMemLoad = 7,    ///< tcgen05.ld (tensor memory -> register)
  kTMemStore = 8,   ///< tcgen05.st (register -> tensor memory)
};

/// Descriptor for Tensor Memory Access (TMA) copy operations.
struct TMADesc {
  size_t rank;                   ///< Tensor rank (number of dimensions)
  int data_type;                 ///< Data type identifier
  Array<PrimExpr> global_shape;  ///< Shape in global memory
  Array<PrimExpr> global_stride; ///< Strides in global memory
  Array<PrimExpr> smem_box;      ///< Block shape in shared memory
  Array<PrimExpr> smem_stride;   ///< Strides in shared memory
  PrimExpr global_addr;          ///< Base address in global memory
  int swizzle;                   ///< Memory layout swizzle parameter
  int interleave;                ///< Memory interleave parameter
  int oob_fill;                  ///< Out-of-bound fill policy
  int l2_promotion;              ///< L2 cache promotion flag

  /*!
   * \brief Encode descriptor fields into runtime call arguments.
   * \return The encoded arguments as an array of expressions. The argument
   *         order is defined by the implementation, which is not part of
   *         this header.
   */
  Array<PrimExpr> EncodeCallArgs() const;
};

/*!
 * \brief Descriptor for TMA-based im2col transformation used in Conv2D.
 *
 * This supports extracting patches from the input image (im2col)
 * for convolution lowering, storing them in shared memory.
 *
 * Unlike TMADesc, the shared-memory box is described by two scalar extents
 * (pixel and channel) and the extraction window is bounded by
 * lower_corner / upper_corner.
 */
struct TMAIm2ColDesc {
  size_t rank;                   ///< Rank of the tensor
  int data_type;                 ///< Data type identifier
  Array<PrimExpr> global_shape;  ///< Shape of input tensor in global memory
  Array<PrimExpr> global_stride; ///< Stride in global memory
  Array<PrimExpr> elem_stride;   ///< Stride at element level (per axis)
  Array<PrimExpr> lower_corner; ///< Lower bound offsets for the extraction
                                ///< window (rank - 2 dims)
  Array<PrimExpr> upper_corner; ///< Upper bound offsets for the extraction
                                ///< window (rank - 2 dims)
  PrimExpr global_addr;         ///< Base address in global memory
  int smem_box_pixel;           ///< Pixel dimension of shared memory box
  int smem_box_channel;         ///< Channel dimension of shared memory box
  int swizzle;                  ///< Memory swizzle setting
  int interleave;               ///< Memory interleaving setting
  int oob_fill;                 ///< Out-of-bound fill policy
  int l2_promotion;             ///< Whether to enable L2 cache promotion

  /*!
   * \brief Encode descriptor fields into runtime arguments.
   * \return The encoded arguments as an array of expressions.
   */
  Array<PrimExpr> EncodeCallArgs() const;
};

/*!
 * \brief Tile operator node describing a copy between a source buffer region
 * and a destination buffer region.
 *
 * The node can be lowered through several instruction paths; see CopyInst
 * for the available copy instruction types.
 */
class CopyNode : public TileOperatorNode {
89
public:
90
91
92
93
94
95
96
  Buffer src, dst;                   // Source and destination buffers
  Array<Range> src_range, dst_range; // Ranges for each dimension in src and dst
  IntImm coalesced_width; // Width (in elements) for coalesced memory access
  Bool disable_tma = Bool(false); // Whether to disable TMA acceleration

  mutable ParallelOp par_op_; // Optional associated parallelization operator

97
  enum class EvictionPolicy : uint8_t {
98
99
100
101
102
    kEvictNormal = 0,
    kEvictFirst = 1,
    kEvictLast = 2,
  };

103
  uint8_t eviction_policy; // Policy for cache eviction
104
105
  static constexpr const char *_type_key = "tl.Copy";
  TVM_DECLARE_FINAL_OBJECT_INFO(CopyNode, TileOperatorNode);
106

107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<CopyNode>()
        .def_ro("src", &CopyNode::src)
        .def_ro("dst", &CopyNode::dst)
        .def_ro("src_range", &CopyNode::src_range)
        .def_ro("dst_range", &CopyNode::dst_range)
        .def_ro("coalesced_width", &CopyNode::coalesced_width);
  }

  bool SEqualReduce(const CopyNode *other, SEqualReducer equal) const {
    return equal(src, other->src) && equal(dst, other->dst) &&
           equal(src_range, other->src_range) &&
           equal(dst_range, other->dst_range) &&
           equal(coalesced_width, other->coalesced_width);
  }

  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(src);
    hash_reduce(dst);
    hash_reduce(src_range);
    hash_reduce(dst_range);
    hash_reduce(coalesced_width);
  }
  static constexpr bool _type_has_method_sequal_reduce = true;
  static constexpr bool _type_has_method_shash_reduce = true;

134
135
136
137
138
  /*!
   * \brief Lower the copy operator to a TIR statement.
   * \param T        Arguments for lowering.
   * \param analyzer Analyzer for simplification and bounds checks.
   */
139
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
140
141
142
143
144
145

  /*!
   * \brief Infer buffer layouts after applying this operator.
   * \param T     Arguments for layout inference.
   * \param level Level of inference (basic or detailed).
   */
146
147
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;
148
149
150
151

  /*!
   * \brief Check if bulk copy is supported.
   */
152
153
  bool CheckBulkLoad(Target target, arith::Analyzer *analyzer,
                     bool check_last_dim = true) const;
154
155
156
157

  /*!
   * \brief Check if bulk store is supported.
   */
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
  bool CheckBulkStore(Target target, arith::Analyzer *analyzer,
                      bool check_last_dim = true) const;

  /*!
   * \brief Check if bulk copy 1d load is supported.
   */
  bool CheckBulkLoad1D(Target target, const LayoutMap &layout_map,
                       arith::Analyzer *analyzer) const;

  /*!
   * \brief Check if bulk copy 1d store is supported.
   */
  bool CheckBulkStore1D(Target target, const LayoutMap &layout_map,
                        arith::Analyzer *analyzer) const;

  /*!
   * \brief Check if bulk copy 1d is supported.
   */
  bool CheckBulkCopy1D(const Buffer &global_tensor, const Buffer &shared_tensor,
                       const Array<Range> &global_range,
                       const Array<Range> &shared_range,
                       const LayoutMap &layout_map,
                       arith::Analyzer *analyzer) const;
181
182
183
184
185
186
187
188
189
190
191

  /*!
   * \brief Check if lds memory copy is supported.
   */
  bool CheckLDSMCopy(Target target) const;

  /*!
   * \brief Check if stsm memory copy is supported.
   */
  bool CheckSTSMCopy(Target target) const;

192
193
194
195
196
197
198
199
200
201
  /*!
   * \brief Check if tensor memory load is supported.
   */
  bool CheckTMemLoad(Target target) const;

  /*!
   * \brief Check if tensor memory store is supported.
   */
  bool CheckTMemStore(Target target) const;

202
203
204
  /*!
   * \brief Get the copy instruction type.
   */
205
206
207
  CopyInst GetCopyInst(Target target, bool disable_tma_lower,
                       const LayoutMap &layout_map, arith::Analyzer *analyzer,
                       bool buffer_oob) const;
208
209
210
211
212
213
214
215

protected:
  /*!
   * \brief Generate lowering for bulk/global-to-shared copy.
   */
  Stmt LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
                     CopyInst copy_inst) const;

216
217
218
219
220
221
  /*!
   * \brief Generate lowering for bulk copy 1d.
   */
  Stmt LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
                       CopyInst copy_inst) const;

222
223
224
225
226
227
  /*!
   * \brief Generate lowering for LDS Memory Copy (shared memory to shared
   * memory or smem usage).
   */
  Stmt LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
                     CopyInst copy_inst) const;
228
229
230
231
232

  /*!
   * \brief Generate lowering for tensor memory copy (tcgen05.ld/st/cp).
   */
  Stmt LowerTmemCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270

  /*!
   * \brief Generate lowering for normal copy.
   */
  Stmt LowerNormalCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;

  /*!
   * \brief Generate SIMT (thread-level) loop for copying.
   */
  For MakeSIMTLoop(arith::Analyzer *analyzer) const;

  /*!
   * \brief Compute linear layout for tma copy.
   */
  Layout ComputeLinearLayout(const Buffer &shared_tensor) const;

  /*!
   * \brief Create iterator variables for multi-dimensional copy loops.
   */
  Array<IterVar> MakeIterVars() const;

  /*!
   * \brief Calculate source or destination indices from iteration vars.
   * \param ivs      Iterator variables from MakeIterVars().
   * \param src_dst  0 = make source indices, 1 = make destination indices.
   */
  Array<PrimExpr> MakeIndices(const Array<IterVar> &ivs, int src_dst) const;

  /*!
   * \brief Construct the boundary predicate for valid copy (to avoid OOB).
   * \param analyzer  Arithmetic analyser for simplification.
   * \param ivs       Iterator variables.
   * \param extents   Extent expressions for the relevant buffer.
   * \param src_dst   0 = predicate for source, 1 = predicate for destination.
   */
  PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
                         Array<PrimExpr> extents, int src_dst) const;

271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
  /**
   * \brief Create a deep copy of this operator.
   *
   * Returns a TileOperator that is a copy of the current node, preserving all
   * configuration (buffers, parameters, and layout-related fields).
   * @return A TileOperator owning the cloned operator node.
   */

  /**
   * \brief Constructor.
   * \param args Expression arguments for the Conv2D im2col operator.
   * \param vmap Buffer variable mapping.
   */

  /**
   * \brief Get the TVM Op handle corresponding to this Conv2DIm2Col operator.
   * @return Reference to the singleton TVM Op representing this operator.
   */
289
290
  TileOperator Clone() const;
};

/*!
 * \brief Reference class for CopyNode.
 */
class Copy : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(Copy, TileOperator, CopyNode);
295

296
297
298
299
300
301
  /*!
   * \brief Constructor.
   * \param args  Expression arguments for the copy.
   * \param vmap  Buffer variable mapping.
   */
  TVM_DLL Copy(Array<PrimExpr> args, BufferMap vmap);
302

303
304
305
306
  /*!
   * \brief Get the TVM Op handle corresponding to this Copy op.
   */
  static const Op &Get();
307
308
309
310
311
312
313
314
};

/*!
 * \brief Special operator for Conv2D im2col transformation.
 *
 * This operator converts the input image layout into a columnar format
 * suitable for matrix-multiplication-based convolution lowering.
 */
class Conv2DIm2ColOpNode : public TileOperatorNode {
316
public:
317
318
319
320
321
322
323
324
325
326
327
  Buffer src, dst; // Source (input feature map) and destination (im2col matrix)
  int stride;      // Stride for convolution
  int padding;     // Padding amount
  int dilation;    // Dilation factor
  int kernel;      // Kernel size
  int eviction_policy; // Cache eviction policy
  PrimExpr nhw_step;   // Step size in NHW dimensions
  PrimExpr c_step;     // Step size in channel dimension

  static constexpr const char *_type_key = "tl.Conv2DIm2Col";
  TVM_DECLARE_FINAL_OBJECT_INFO(Conv2DIm2ColOpNode, TileOperatorNode);
328

329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<Conv2DIm2ColOpNode>()
        .def_ro("src", &Conv2DIm2ColOpNode::src)
        .def_ro("dst", &Conv2DIm2ColOpNode::dst)
        .def_ro("stride", &Conv2DIm2ColOpNode::stride)
        .def_ro("padding", &Conv2DIm2ColOpNode::padding)
        .def_ro("dilation", &Conv2DIm2ColOpNode::dilation)
        .def_ro("kernel", &Conv2DIm2ColOpNode::kernel)
        .def_ro("eviction_policy", &Conv2DIm2ColOpNode::eviction_policy);
  }

  bool SEqualReduce(const Conv2DIm2ColOpNode *other,
                    SEqualReducer equal) const {
    return equal(src, other->src) && equal(dst, other->dst) &&
           equal(stride, other->stride) && equal(padding, other->padding) &&
           equal(dilation, other->dilation) && equal(kernel, other->kernel) &&
           equal(eviction_policy, other->eviction_policy);
  }

  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(src);
    hash_reduce(dst);
    hash_reduce(stride);
    hash_reduce(padding);
    hash_reduce(dilation);
    hash_reduce(kernel);
    hash_reduce(eviction_policy);
  }
  static constexpr bool _type_has_method_sequal_reduce = true;
  static constexpr bool _type_has_method_shash_reduce = true;

361
362
363
  /*!
   * \brief Lower to TIR statement.
   */
364
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
365
366

  /*!
367
   * \brief Infer layout for this operator.
368
   */
369
370
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;
371
372

  /*!
373
   * \brief Get TVM Op handle.
374
   */
375
376
377
  static const Op &Get();
  TileOperator Clone() const;
};

/*!
 * \brief Reference class for Conv2DIm2ColOpNode.
 */
class Conv2DIm2ColOp : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(Conv2DIm2ColOp, TileOperator,
                                Conv2DIm2ColOpNode);

  /*!
   * \brief Constructor.
   * \param args Expression arguments for the Conv2D im2col operator.
   * \param vmap Buffer variable mapping.
   */
  TVM_DLL Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap);

  /*!
   * \brief Get the TVM Op handle corresponding to this Conv2DIm2Col operator.
   * \return Reference to the Op handle.
   */
  static const Op &Get();
};

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_COPY_H_