"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "0246f32b13fcf585185657b2a23d73a12d14c236"
copy.h 16.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
/*!
 * \file tl/op/copy.h
 * \brief Define element-wise and copy-related operators for TVM TensorIR
 * Lowering.
 *
 * This header declares the Copy operator and related operator descriptors
 * such as TMADesc and TMAIm2ColDesc, as well as a Conv2DIm2Col special
 * operator.
 */

#ifndef TVM_TL_OP_COPY_H_
#define TVM_TL_OP_COPY_H_

14
#include "operator.h"
15
16
17
18
19
20
#include "parallel.h"

namespace tvm {
namespace tl {
using namespace tir;

21
22
23
/*!
 * \brief Copy instruction type.
 *
 * Names the mechanism a copy is lowered to. The enumerator values are fixed
 * explicitly and stored in a uint8_t.
 */
enum class CopyInst : uint8_t {
  kNormal = 0,    // utilize ldg/stg or cpasync or any buffer copy
  kLDSM = 1,      // ldmatrix memory copy
  kSTSM = 2,      // stmatrix memory copy
  kBulkLoad = 3,  // utilize tma load
  kBulkStore = 4, // utilize tma store
};

32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
/*!
 * \brief Descriptor for Tensor Memory Accelerator (TMA) copy operations.
 *
 * Contains meta-information required to perform global-to-shared memory copy
 * using Tensor Memory Accelerator (TMA) hardware instructions. It is mainly
 * used to describe the shape, strides, and data layout for both source and
 * shared memory buffers.
 *
 * Scalar fields carry default member initializers so that a
 * default-constructed descriptor never holds indeterminate values.
 */
struct TMADesc {
  // Tensor rank (number of dimensions)
  size_t rank = 0;
  // Data type identifier (numeric code)
  int data_type = 0;
  // Shape of the source tensor in global memory
  Array<PrimExpr> global_shape;
  // Strides of the source tensor in global memory
  Array<PrimExpr> global_stride;
  // Block shape in shared memory
  Array<PrimExpr> smem_box;
  // Strides in shared memory layout
  Array<PrimExpr> smem_stride;
  // Base address in global memory
  PrimExpr global_addr;
  // Swizzle parameter for memory layout transform
  int swizzle = 0;
  // Interleave parameter
  int interleave = 0;
  // Out-of-bound fill policy
  int oob_fill = 0;
  // L2 cache promotion setting
  int l2_promotion = 0;

  /*!
   * \brief Encode descriptor fields into an argument array for runtime calls.
   * \return The encoded arguments as an array of expressions.
   */
  Array<PrimExpr> EncodeCallArgs() const;
};

/*!
 * \brief Descriptor for TMA-based im2col transformation used in Conv2D.
 *
 * This supports extracting patches from the input image (im2col)
 * for convolution lowering, storing them in shared memory.
 *
 * Scalar fields carry default member initializers so that a
 * default-constructed descriptor never holds indeterminate values.
 */
struct TMAIm2ColDesc {
  // Rank of the tensor
  size_t rank = 0;
  // Data type identifier
  int data_type = 0;
  // Shape of input tensor in global memory
  Array<PrimExpr> global_shape;
  // Stride in global memory
  Array<PrimExpr> global_stride;
  // Stride at element level (per axis)
  Array<PrimExpr> elem_stride;
  // Lower bound offsets for the extraction window (rank - 2 dims)
  Array<PrimExpr> lower_corner;
  // Upper bound offsets for the extraction window (rank - 2 dims)
  Array<PrimExpr> upper_corner;
  // Base address in global memory
  PrimExpr global_addr;
  // Pixel dimension of shared memory box
  int smem_box_pixel = 0;
  // Channel dimension of shared memory box
  int smem_box_channel = 0;
  // Memory swizzle setting
  int swizzle = 0;
  // Memory interleaving setting
  int interleave = 0;
  // Out-of-bound fill policy
  int oob_fill = 0;
  // L2 cache promotion setting
  int l2_promotion = 0;

  /*!
   * \brief Encode descriptor fields into runtime arguments.
   * \return The encoded arguments as an array of expressions.
   */
  Array<PrimExpr> EncodeCallArgs() const;
};

/*!
 * \brief Copy operator for transferring data between buffers.
 *
 * Performs element- or block-wise copies between `src` and `dst` buffers for
 * TensorIR lowering. The operator supports thread-level parallelization,
 * shared-memory layouts, and hardware-accelerated paths (TMA/LDSM/STMATRIX)
 * when available. Public fields describe the copy ranges and tuning knobs
 * (coalesced width, eviction policy, disable_tma).
 */

/*!
 * \brief Lower the copy operator to a TIR statement.
 *
 * Produces a TIR statement implementing the configured copy (normal, LDSM,
 * STSM, or bulk TMA-based) for the given lowering context.
 *
 * \param T        Lowering arguments that provide buffer bindings and context.
 * \param analyzer Analyzer used for expression simplification and bounds
 *                 checks.
 * \return         A TIR `Stmt` implementing the copy.
 */

/*!
 * \brief Infer buffer layouts after applying this operator.
 *
 * Computes resulting layouts (shape/stride mappings) for buffers affected by
 * this copy operation.
 *
 * \param T     Arguments for layout inference (buffer maps, shapes).
 * \param level Granularity of inference to perform.
 * \return      A LayoutMap describing inferred layouts.
 */

/*!
 * \brief Check if bulk global->shared copy is supported on the target.
 *
 * Returns true if the target supports bulk (TMA) loads from global memory.
 *
 * \param target Target to query.
 */

/*!
 * \brief Check if bulk shared->global store is supported on the target.
 *
 * Returns true if the target supports bulk (TMA) stores to global memory.
 *
 * \param target Target to query.
 */

/*!
 * \brief Check if LDSM (LDMATRIX) memory-copy is supported on the target.
 *
 * \param target Target to query.
 */

/*!
 * \brief Check if STSM (STMATRIX) memory-copy is supported on the target.
 *
 * \param target Target to query.
 */

/*!
 * \brief Select the copy instruction type to use.
 *
 * Chooses between kNormal, kLDSM, kSTSM, kBulkLoad, and kBulkStore based on
 * the target capabilities and whether TMA lowering is disabled.
 *
 * \param target            Target to query.
 * \param disable_tma_lower When true, force non-TMA copy paths.
 * \return                  The selected CopyInst value.
 */

/*!
 * \brief Clone this copy operator.
 *
 * Returns a TileOperator reference that is a shallow clone of this operator
 * object suitable for further modifications in pass pipelines.
 */

/*!
 * \brief Generate lowering for bulk (global-to-shared or shared-to-global)
 * copy.
 *
 * Implements TMA-based bulk load/store lowering when `copy_inst` indicates a
 * bulk path. The function encodes TMA descriptors and produces calls or
 * loops required by the selected bulk mechanism.
 *
 * \param T         Lowering context.
 * \param analyzer  Analyzer for simplification.
 * \param copy_inst Copy instruction type indicating bulk load/store.
 * \return          A TIR `Stmt` implementing the bulk copy.
 */

/*!
 * \brief Generate lowering for LDS matrix-copy paths (LDMATRIX/STMATRIX).
 *
 * Emits the lowering for LDS-based matrix-copy instructions when the chosen
 * `copy_inst` is an LDSM or STSM variant.
 *
 * \param T         Lowering context.
 * \param analyzer  Analyzer for simplification.
 * \param copy_inst Copy instruction type indicating an LDS matrix path.
 * \return          A TIR `Stmt` implementing the matrix-copy.
 */

/*!
 * \brief Generate lowering for the normal (non-bulk, scalar/vec) copy path.
 *
 * Emits element-wise or vectorized loads/stores using the computed iteration
 * space and predicates to ensure in-bounds accesses.
 *
 * \param T        Lowering context.
 * \param analyzer Analyzer for simplification.
 * \return         A TIR `Stmt` implementing the normal copy.
 */

/*!
 * \brief Generate a SIMT-style thread-level loop for the copy.
 *
 * Produces a `For` loop that distributes copy work across SIMD/warp lanes or
 * CUDA threads according to the operator's iteration strategy.
 *
 * \param analyzer Analyzer for simplification.
 * \return         A `For` loop representing the thread-level iteration.
 */

/*!
 * \brief Compute a linear shared-memory layout suitable for TMA copies.
 *
 * Returns a `Layout` that maps the shared-memory `shared_tensor` into a
 * linearized representation required by bulk/TMA transfers.
 *
 * \param shared_tensor Buffer representing the shared-memory tensor.
 * \return              A `Layout` describing the linearized shared layout.
 */

/*!
 * \brief Create iterator variables for multi-dimensional copy loops.
 *
 * The returned `IterVar` array enumerates the loop indices used to traverse
 * the copy extents in each tensor dimension.
 *
 * \return Array of iterator variables.
 */

/*!
 * \brief Calculate source or destination indices from iteration variables.
 *
 * Converts the iterator variables (from MakeIterVars) into concrete index
 * expressions for either the source image or the destination tensor.
 *
 * \param ivs     Iterator variables returned by MakeIterVars().
 * \param src_dst 0 to produce source indices, 1 to produce destination indices.
 * \return        Array of `PrimExpr` index expressions.
 */

/*!
 * \brief Construct the boundary predicate ensuring in-bounds accesses.
 *
 * Builds a boolean expression that guards loads/stores so they only occur
 * when indices lie within the provided `extents`.
 *
 * \param analyzer Arithmetic analyzer used to simplify predicates.
 * \param ivs      Iterator variables.
 * \param extents  Extent expressions for the target buffer.
 * \param src_dst  0 = predicate for source indices, 1 = predicate for
 *                 destination indices.
 * \return         A `PrimExpr` boolean predicate.
 */

/*!
 * \brief Constructor.
 *
 * \param args Expression arguments for the copy (indices, sizes, etc.).
 * \param vmap Buffer variable mapping for source and destination.
 */

/*!
 * \brief Get the TVM Op handle corresponding to this Copy op.
 */

/*!
 * \brief Special operator for Conv2D im2col transformation.
 *
 * Converts an input feature map into an im2col matrix layout used for GEMM-
 * based convolution lowering. Public fields configure kernel geometry,
 * stride/padding/dilation, and cache eviction behavior.
 */

/*!
 * \brief Lower to TIR statement.
 *
 * Emits TIR that performs the im2col extraction from `src` into `dst`
 * according to kernel, stride, padding, and dilation parameters.
 *
 * \param T        Lowering context with buffer bindings.
 * \param analyzer Analyzer for expression simplification and bounds reasoning.
 * \return         A TIR `Stmt` performing the im2col transform.
 */

/*!
 * \brief Infer layout for this operator.
 *
 * Produces the layout mapping for the destination im2col matrix given the
 * source layout and convolution parameters.
 *
 * \param T     Layout inference arguments.
 * \param level Inference granularity level.
 * \return      A LayoutMap with inferred layouts for affected buffers.
 */

/*!
 * \brief Get TVM Op handle for Conv2DIm2Col.
 */

/*!
 * \brief Clone this Conv2DIm2Col operator.
 *
 * Returns a TileOperator reference that is a shallow clone of this operator.
 */
class CopyNode : public TileOperatorNode {
309
public:
310
311
312
313
314
315
316
  Buffer src, dst;                   // Source and destination buffers
  Array<Range> src_range, dst_range; // Ranges for each dimension in src and dst
  IntImm coalesced_width; // Width (in elements) for coalesced memory access
  Bool disable_tma = Bool(false); // Whether to disable TMA acceleration

  mutable ParallelOp par_op_; // Optional associated parallelization operator

317
  enum class EvictionPolicy : uint8_t {
318
319
320
321
322
    kEvictNormal = 0,
    kEvictFirst = 1,
    kEvictLast = 2,
  };

323
  uint8_t eviction_policy; // Policy for cache eviction
324
325
  static constexpr const char *_type_key = "tl.Copy";
  TVM_DECLARE_FINAL_OBJECT_INFO(CopyNode, TileOperatorNode);
326
327
328
329
330
331

  /*!
   * \brief Lower the copy operator to a TIR statement.
   * \param T        Arguments for lowering.
   * \param analyzer Analyzer for simplification and bounds checks.
   */
332
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
333
334
335
336
337
338

  /*!
   * \brief Infer buffer layouts after applying this operator.
   * \param T     Arguments for layout inference.
   * \param level Level of inference (basic or detailed).
   */
339
  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419

  /*!
   * \brief Check if bulk copy is supported.
   */
  bool CheckBulkLoad(Target target) const;

  /*!
   * \brief Check if bulk store is supported.
   */
  bool CheckBulkStore(Target target) const;

  /*!
   * \brief Check if lds memory copy is supported.
   */
  bool CheckLDSMCopy(Target target) const;

  /*!
   * \brief Check if stsm memory copy is supported.
   */
  bool CheckSTSMCopy(Target target) const;

  /*!
   * \brief Get the copy instruction type.
   */
  CopyInst GetCopyInst(Target target, bool disable_tma_lower) const;

  /*!
   * \brief Clone this copy operator.
   */
protected:
  /*!
   * \brief Generate lowering for bulk/global-to-shared copy.
   */
  Stmt LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
                     CopyInst copy_inst) const;

  /*!
   * \brief Generate lowering for LDS Memory Copy (shared memory to shared
   * memory or smem usage).
   */
  Stmt LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
                     CopyInst copy_inst) const;

  /*!
   * \brief Generate lowering for normal copy.
   */
  Stmt LowerNormalCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;

  /*!
   * \brief Generate SIMT (thread-level) loop for copying.
   */
  For MakeSIMTLoop(arith::Analyzer *analyzer) const;

  /*!
   * \brief Compute linear layout for tma copy.
   */
  Layout ComputeLinearLayout(const Buffer &shared_tensor) const;

  /*!
   * \brief Create iterator variables for multi-dimensional copy loops.
   */
  Array<IterVar> MakeIterVars() const;

  /*!
   * \brief Calculate source or destination indices from iteration vars.
   * \param ivs      Iterator variables from MakeIterVars().
   * \param src_dst  0 = make source indices, 1 = make destination indices.
   */
  Array<PrimExpr> MakeIndices(const Array<IterVar> &ivs, int src_dst) const;

  /*!
   * \brief Construct the boundary predicate for valid copy (to avoid OOB).
   * \param analyzer  Arithmetic analyser for simplification.
   * \param ivs       Iterator variables.
   * \param extents   Extent expressions for the relevant buffer.
   * \param src_dst   0 = predicate for source, 1 = predicate for destination.
   */
  PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
                         Array<PrimExpr> extents, int src_dst) const;

420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
  /**
   * \brief Create a deep copy of this operator.
   *
   * Returns a TileOperator that is a copy of the current node, preserving all
   * configuration (buffers, parameters, and layout-related fields).
   * @return A TileOperator owning the cloned operator node.
   */

  /**
   * \brief Constructor.
   * \param args Expression arguments for the Conv2D im2col operator.
   * \param vmap Buffer variable mapping.
   */

  /**
   * \brief Get the TVM Op handle corresponding to this Conv2DIm2Col operator.
   * @return Reference to the singleton TVM Op representing this operator.
   */
438
439
  TileOperator Clone() const;
};
440

441
442
443
class Copy : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(Copy, TileOperator, CopyNode);
444

445
446
447
448
449
450
  /*!
   * \brief Constructor.
   * \param args  Expression arguments for the copy.
   * \param vmap  Buffer variable mapping.
   */
  TVM_DLL Copy(Array<PrimExpr> args, BufferMap vmap);
451

452
453
454
455
  /*!
   * \brief Get the TVM Op handle corresponding to this Copy op.
   */
  static const Op &Get();
456
457
458
459
460
461
462
463
};

/*!
 * \brief Special operator for Conv2D im2col transformation.
 *
 * This operator converts input image layout into columnar format suitable
 * for matrix multiplication-based convolution lowering.
 */
464
class Conv2DIm2ColOpNode : public TileOperatorNode {
465
public:
466
467
468
469
470
471
472
473
474
475
476
  Buffer src, dst; // Source (input feature map) and destination (im2col matrix)
  int stride;      // Stride for convolution
  int padding;     // Padding amount
  int dilation;    // Dilation factor
  int kernel;      // Kernel size
  int eviction_policy; // Cache eviction policy
  PrimExpr nhw_step;   // Step size in NHW dimensions
  PrimExpr c_step;     // Step size in channel dimension

  static constexpr const char *_type_key = "tl.Conv2DIm2Col";
  TVM_DECLARE_FINAL_OBJECT_INFO(Conv2DIm2ColOpNode, TileOperatorNode);
477
478
479
480

  /*!
   * \brief Lower to TIR statement.
   */
481
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
482
483

  /*!
484
   * \brief Infer layout for this operator.
485
   */
486
  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
487
488

  /*!
489
   * \brief Get TVM Op handle.
490
   */
491
492
493
  static const Op &Get();
  TileOperator Clone() const;
};
494

495
496
497
498
499
500
/*!
 * \brief Managed reference to Conv2DIm2ColOpNode.
 */
class Conv2DIm2ColOp : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(Conv2DIm2ColOp, TileOperator,
                                Conv2DIm2ColOpNode);

  /*!
   * \brief Constructor.
   * \param args Expression arguments for the Conv2D im2col operator.
   * \param vmap Buffer variable mapping.
   */
  TVM_DLL Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap);

  /*!
   * \brief Get the TVM Op handle corresponding to this Conv2DIm2Col operator.
   */
  static const Op &Get();
};

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_COPY_H_