"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "8b651eee0c7878ac57f7b8ad5b2ee5936351ff86"
gemm.h 4.01 KB
Newer Older
1
2
3
4
5
6
7
8
9
/*!
 * \file tl/op/gemm.h
 * \brief Define gemm operator.
 *
 */

#ifndef TVM_TL_OP_GEMM_H_
#define TVM_TL_OP_GEMM_H_

10
#include "operator.h"
11
12

namespace tvm {
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
/**
 * Check whether the target and configuration allow using WGMMA (warpgroup-level
 * MMA, i.e. the NVIDIA Hopper `wgmma` instructions) for this GEMM.
 *
 * @returns true if WGMMA can be used for the current node configuration and
 * target; false otherwise.
 */
/**
 * Lower this GEMM operator to a TVM Stmt for the given lowering context.
 *
 * @param T Lowering arguments and context (tile mappings, target, etc.).
 * @param analyzer Arithmetic analyzer used for symbolic simplification and
 * bounds reasoning.
 * @returns A lowered Stmt implementing the GEMM.
 */
/**
 * Infer memory/layout mapping for GEMM inputs/outputs at the given inference
 * level.
 *
 * @param T Layout inference inputs (buffers, shapes, constraints).
 * @param level Inference level that controls how aggressive/specific the
 * inferred layouts should be.
 * @returns A LayoutMap describing how logical tensor axes map to storage/layout
 * axes.
 */
/**
 * Create a copy of this GemmNode and return it as a TileOperator reference.
 * (Whether the copy is deep or shallow is not visible from this header.)
 *
 * @returns A TileOperator reference that represents a clone of this GemmNode.
 */
/**
 * Determine which GEMM instruction family to use for the given block size and
 * target.
 *
 * @param block_size Block size used to select the instruction variant
 * (presumably the number of threads in the block — confirm).
 * @param target The compilation target describing architecture and instruction
 * set.
 * @returns The GemmInst enum value representing the chosen GEMM instruction
 * family.
 */
/**
 * Compute how to partition the available warps for the given GEMM instruction.
 *
 * The returned pair is (warp_rows, warp_cols): how the `num_warps` warps are
 * split across the row and column dimensions (presumably
 * warp_rows * warp_cols == num_warps — confirm).
 *
 * @param num_warps Total number of warps available for the block.
 * @param gemm_inst The GEMM instruction variant selected for the target.
 * @param target The compilation target, which may constrain or influence the
 * partitioning.
 * @returns A pair<int,int> = (warp_rows, warp_cols) describing the warp
 * partition.
 */
/**
 * Construct a Gemm operator handle from call arguments and a buffer map.
 *
 * @param args Array of call-time PrimExpr arguments passed to the operator.
 * @param vmap Buffer map (BufferMap) used to resolve the tir::Buffer objects
 * referenced by `args`.
 */
/**
 * Obtain the registered Op descriptor for the GEMM operator.
 *
 * @returns A const reference to the registered GEMM Op. Note that "tl.Gemm" is
 * the node's `_type_key`; the registered Op name may differ — confirm.
 */
81
82
83
84
namespace tl {

using namespace tir;

85
86
87
88
89
/**
 * Policy selecting how a block's warps are partitioned for a GEMM.
 *
 * NOTE(review): the per-enumerator meanings below are inferred from the names
 * and should be confirmed against GemmNode::ComputeWarpPartition.
 */
enum class GemmWarpPolicy {
  kSquare = 0,  ///< presumably balances warps between rows and columns
  kFullRow = 1, ///< presumably places all warps along the row dimension
  kFullCol = 2  ///< presumably places all warps along the column dimension
};
90

91
92
class GemmNode : public TileOperatorNode {
public:
93
  bool CheckWGMMA() const;
94
95
  Array<PrimExpr> call_args;
  tir::Buffer A, B, C;
96
97
  // pointer to the A, B, C
  PrimExpr Aptr, Bptr, Cptr;
98
99
  bool trans_A, trans_B;
  int M, N, K;
100
101
  int stride_A, stride_B;
  int offset_A, offset_B;
102
  bool clear_accum = false;
103
104
  // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
  // only will be enabled under cdna mfma instructions
105
  int kPack = 1;
106
  int wg_wait = 0;
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
  GemmWarpPolicy policy;

  static constexpr const char *_type_key = "tl.Gemm";
  TVM_DECLARE_FINAL_OBJECT_INFO(GemmNode, TileOperatorNode);

  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;

  TileOperator Clone() const;

private:
  // Target GEMM instruction
  enum class GemmInst { kMMA, kWGMMA, kUTCMMA, kMFMA };
  GemmInst GetGemmInst(int block_size, Target target) const;

  std::pair<int, int> ComputeWarpPartition(int num_warps, GemmInst gemm_inst,
                                           Target target) const;

  mutable bool completed_ = false;
};

/**
 * Reference (handle) type wrapping a GemmNode.
 */
class Gemm : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(Gemm, TileOperator, GemmNode);

  /**
   * Construct a Gemm operator from call arguments and a buffer map.
   *
   * @param args Call-time PrimExpr arguments passed to the operator.
   * @param vmap Buffer map used to resolve the tir::Buffer objects referenced
   * by `args`.
   */
  TVM_DLL Gemm(Array<PrimExpr> args, BufferMap vmap);

  /**
   * Obtain the registered Op descriptor for the GEMM operator.
   *
   * @returns A const reference to the registered GEMM Op.
   */
  static const Op &Get();
};

136
137
} // namespace tl
} // namespace tvm
138

139
#endif //  TVM_TL_OP_GEMM_H_