/*!
 * \file tl/ir.cc
 * \brief TileLang extensions to the TVM script IR builder frontend:
 *        parallel and pipelined loop frames, kernel launch frames, and
 *        warp-specialization frames.
 *
 */

7
#include <tvm/arith/analyzer.h>
8
9
10
11
12
#include <tvm/script/ir_builder/tir/ir.h>

namespace tvm {
namespace tl {

13
14
15
/*!
 * \brief Attribute key that marks a kernel launch as a CPU kernel.
 *
 * KernelLaunch checks whether its `attrs` map contains this key; when it does,
 * serial loops over the grid are emitted instead of GPU thread bindings.
 */
constexpr const char *tilelang_is_cpu_kernel_frame =
    "tilelang.is_cpu_kernel_frame";

16
17
using namespace script::ir_builder::tir;

18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
/*!
 * \brief Create a thread-index variable and record it on the enclosing
 *        PrimFunc frame.
 * \param name Name hint of the variable to create.
 * \param thread_tag Thread tag attached to the iteration variable.
 * \param dtype Data type of the variable.
 * \return The variable that backs the newly created thread IterVar.
 * \note Aborts with LOG(FATAL) when the current IRBuilder has no PrimFunc
 *       frame on its stack.
 */
static Var CreateEnvThread(String name, String thread_tag, DataType dtype) {
  using namespace tvm::tir;
  using namespace tvm::script::ir_builder;
  // Build the thread axis first; its domain is left undefined on purpose.
  Var thread_var(name, dtype);
  IterVar thread_axis(Range{nullptr}, thread_var,
                      tvm::tir::IterVarType::kThreadIndex, thread_tag);

  Optional<PrimFuncFrame> func_frame =
      IRBuilder::Current()->FindFrame<PrimFuncFrame>();
  if (!func_frame) {
    LOG(FATAL) << "EnvThread can only be used inside a PrimFunc";
  }
  // Remember the axis so later lookups by variable find the same IterVar.
  func_frame.value()->env_threads.Set(thread_var, thread_axis);
  return thread_var;
}

33
34
/*!
 * \brief Create a frame for a single serial loop over `[0, dom)`.
 * \param name Name hint of the loop variable.
 * \param dom Extent of the loop; its dtype is also used for the loop variable
 *        and for the zero lower bound.
 * \return A ForFrame holding exactly one variable and one range, whose loop
 *         maker emits a `ForKind::kSerial` loop.
 */
static ForFrame MakeIterVarFrame(std::string name, PrimExpr dom) {
  using namespace tvm::tir;
  const DataType dtype = dom.dtype();
  // Create a frame that represents a loop over the given domain.
  ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
  n->vars.push_back(Var(name, dtype));
  // Build the lower bound in the extent's dtype (as ParallelFor does) so that
  // loop variable, min and extent always agree, including for non-int32
  // extents.
  n->doms.push_back(Range(make_const(dtype, 0), dom));
  n->f_make_for_loop = [](Array<Var> vars, Array<Range> doms,
                          Stmt body) -> Stmt {
    ICHECK_EQ(vars.size(), 1);
    ICHECK_EQ(doms.size(), 1);
    return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial,
               std::move(body));
  };
  return ForFrame(n);
}

49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
/*!
 * \brief Create a frame for a nest of parallel loops.
 * \param extents Extent of each loop; one loop variable ranging over
 *        `[0, extent)` is created per entry, in the entry's dtype.
 * \param annotations Annotations attached to every generated loop.
 * \return A ForFrame whose loop maker wraps the body in `ForKind::kParallel`
 *         loops, with the first extent as the outermost loop.
 */
ForFrame ParallelFor(Array<PrimExpr> extents,
                     Map<String, ObjectRef> annotations) {
  using namespace tvm::tir;
  ObjectPtr<ForFrameNode> frame_node = make_object<ForFrameNode>();
  const size_t num_axes = extents.size();
  frame_node->vars.reserve(num_axes);
  frame_node->doms.reserve(num_axes);
  for (size_t axis = 0; axis < num_axes; ++axis) {
    PrimExpr extent = extents[axis];
    const DataType axis_dtype = extent.dtype();
    frame_node->vars.push_back(Var("v", axis_dtype));
    frame_node->doms.push_back(Range(make_const(axis_dtype, 0), extent));
  }
  frame_node->f_make_for_loop = [annotations](Array<Var> vars,
                                              Array<Range> doms,
                                              Stmt body) -> Stmt {
    ICHECK_EQ(vars.size(), doms.size());
    Stmt nest = std::move(body);
    // Wrap from the innermost axis outwards so axis 0 ends up outermost.
    for (size_t depth = vars.size(); depth > 0; --depth) {
      const size_t axis = depth - 1;
      nest = For(vars[axis], doms[axis]->min, doms[axis]->extent,
                 ForKind::kParallel, std::move(nest),
                 /*thread_binding=*/NullOpt, /*annotations=*/annotations);
    }
    return nest;
  };
  return ForFrame(frame_node);
}

/*!
 * \brief Create a frame for a single serial loop carrying software-pipeline
 *        annotations.
 * \param start First argument of the loop's Range.
 * \param stop Second argument of the loop's Range; its dtype is used for the
 *        loop variable.
 * \param num_stages Stored under "num_stages" when positive.
 * \param order Stored under "tl_pipeline_order" when non-empty.
 * \param stages Stored under "tl_pipeline_stage" when non-empty.
 * \param sync Stored under "tl_pipeline_sync" when non-empty.
 * \param groups Stored under "tl_pipeline_group" when non-empty.
 * \return A ForFrame whose loop maker emits one annotated `ForKind::kSerial`
 *         loop.
 */
ForFrame PipelinedFor(PrimExpr start, PrimExpr stop, int num_stages,
                      Array<PrimExpr> order, Array<PrimExpr> stages,
                      Array<Array<PrimExpr>> sync,
                      Array<Array<PrimExpr>> groups) {
  using namespace tvm::tir;
  ObjectPtr<ForFrameNode> frame_node = make_object<ForFrameNode>();
  frame_node->vars.push_back(Var("v", stop.dtype()));
  frame_node->doms.push_back(Range(start, stop));
  frame_node->f_make_for_loop = [num_stages, order, stages, sync,
                                 groups](Array<Var> vars, Array<Range> doms,
                                         Stmt body) -> Stmt {
    ICHECK_EQ(vars.size(), doms.size());
    int n = vars.size();
    ICHECK(n == 1);

    // Only options that were actually provided become annotations.
    Map<String, ObjectRef> anno;
    if (num_stages > 0) {
      anno.Set("num_stages", PrimExpr(num_stages));
    }
    if (!order.empty()) {
      anno.Set("tl_pipeline_order", order);
    }
    if (!stages.empty()) {
      anno.Set("tl_pipeline_stage", stages);
    }
    if (!sync.empty()) {
      anno.Set("tl_pipeline_sync", sync);
    }
    if (!groups.empty()) {
      anno.Set("tl_pipeline_group", groups);
    }

    const Range loop_range = doms[0];
    return For(vars[0], loop_range->min, loop_range->extent, ForKind::kSerial,
               std::move(body),
               /*thread_binding=*/NullOpt, /*annotations=*/anno);
  };
  return ForFrame(frame_node);
}

/*!
 * \brief Frame node for a kernel launch; it owns a list of child frames and
 *        forwards scope entry/exit to them.
 *
 * \sa KernelLaunchFrame
 */
class KernelLaunchFrameNode : public TIRFrameNode {
public:
  /*! \brief Child frames, entered front-to-back and exited back-to-front. */
  Array<TIRFrame> frames;

  void VisitAttrs(tvm::AttrVisitor *v) {
    TIRFrameNode::VisitAttrs(v);
    v->Visit("frames", &frames);
  }

  static constexpr const char *_type_key = "tl.KernelLaunchFrame";
  TVM_DECLARE_FINAL_OBJECT_INFO(KernelLaunchFrameNode, TIRFrameNode);

  /*!
   * \brief The method called when entering RAII scope; enters every child
   *        frame in order.
   */
  TVM_DLL void EnterWithScope() final {
    for (TIRFrame child : frames) {
      child->EnterWithScope();
    }
  }
  /*!
   * \brief The method called when exiting RAII scope; exits the child frames
   *        in reverse order of entry.
   * \sa tvm::support::With
   */
  TVM_DLL void ExitWithScope() final {
    for (size_t remaining = frames.size(); remaining > 0; --remaining) {
      frames[remaining - 1]->ExitWithScope();
    }
  }
};

/*!
 * \brief Managed reference to KernelLaunchFrameNode.
 *
 * Declared as a mutable, non-nullable object reference deriving from TIRFrame.
 *
 * \sa KernelLaunchFrameNode
 */
class KernelLaunchFrame : public TIRFrame {
public:
  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(KernelLaunchFrame, TIRFrame,
                                                    KernelLaunchFrameNode);
};

/*!
 * \brief Build the frame for a kernel launch.
 *
 * When `attrs` contains the `tilelang_is_cpu_kernel_frame` key, one serial loop
 * frame named `block_var_<i>` is created per grid dimension and no block size
 * is allowed. Otherwise up to three grid dimensions are bound to
 * `blockIdx.x/y/z` and, when `block_size` is defined, up to three block
 * dimensions are bound to `threadIdx.x/y/z`. In both cases a "root" block
 * frame is appended last; it carries `attrs` as annotations when `attrs` is
 * defined.
 *
 * \param grid_size Extents of the launch grid.
 * \param block_size Extents of the thread block (GPU launches only).
 * \param attrs Optional attributes attached to the root block.
 * \return The frame holding all child frames in creation order.
 */
KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size,
                               Array<PrimExpr> block_size,
                               Map<String, ObjectRef> attrs) {
  ObjectPtr<KernelLaunchFrameNode> n = make_object<KernelLaunchFrameNode>();

  // If the kernel is a CPU kernel, we don't need to launch any threads.
  const bool is_cpu_kernel_frame =
      attrs.defined() && attrs.count(tilelang_is_cpu_kernel_frame);

  if (is_cpu_kernel_frame) {
    // Launch CPU Kernel
    ICHECK(block_size.size() == 0) << "CPU kernel cannot have block size";
    // Create one grid loop variable per grid dimension.
    for (size_t i = 0; i < grid_size.size(); ++i) {
      n->frames.push_back(
          MakeIterVarFrame("block_var_" + std::to_string(i), grid_size[i]));
    }
  } else {
    // Launch GPU Kernel
    static const char *const kBlockVarNames[3] = {"bx", "by", "bz"};
    static const char *const kBlockTags[3] = {"blockIdx.x", "blockIdx.y",
                                              "blockIdx.z"};
    static const char *const kThreadVarNames[3] = {"tx", "ty", "tz"};
    static const char *const kThreadTags[3] = {"threadIdx.x", "threadIdx.y",
                                               "threadIdx.z"};

    ICHECK(grid_size.size() <= 3);
    for (size_t i = 0; i < grid_size.size(); ++i) {
      n->frames.push_back(LaunchThread(
          CreateEnvThread(kBlockVarNames[i], kBlockTags[i],
                          grid_size[i].dtype()),
          grid_size[i]));
    }
    if (block_size.defined()) {
      ICHECK(block_size.size() <= 3);
      for (size_t i = 0; i < block_size.size(); ++i) {
        n->frames.push_back(LaunchThread(
            CreateEnvThread(kThreadVarNames[i], kThreadTags[i],
                            block_size[i].dtype()),
            block_size[i]));
      }
    }
  }

  // The root block always comes last so it is the innermost frame.
  auto root_block = Block("root");
  if (attrs.defined()) {
    root_block->annotations = attrs;
  }
  n->frames.push_back(root_block);

  return KernelLaunchFrame(n);
}

// Register KernelLaunchFrameNode as a TVM node type.
TVM_REGISTER_NODE_TYPE(KernelLaunchFrameNode);

// Register the frame constructors as global functions under the "tl." prefix.
TVM_REGISTER_GLOBAL("tl.Parallel").set_body_typed(ParallelFor);
TVM_REGISTER_GLOBAL("tl.Pipelined").set_body_typed(PipelinedFor);
TVM_REGISTER_GLOBAL("tl.KernelLaunch").set_body_typed(KernelLaunch);

223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
/*!
 * \brief Frame node for a warp-specialized region; it owns a list of child
 *        frames and forwards scope entry/exit to them.
 *
 * \sa WarpSpecializeFrame
 */
class WarpSpecializeFrameNode : public TIRFrameNode {
public:
  /*! \brief Child frames, entered front-to-back and exited back-to-front. */
  Array<TIRFrame> frames;

  void VisitAttrs(tvm::AttrVisitor *v) {
    TIRFrameNode::VisitAttrs(v);
    v->Visit("frames", &frames);
  }

  static constexpr const char *_type_key = "tl.WarpSpecializeFrame";
  TVM_DECLARE_FINAL_OBJECT_INFO(WarpSpecializeFrameNode, TIRFrameNode);

  /*!
   * \brief The method called when entering RAII scope; enters every child
   *        frame in order.
   */
  TVM_DLL void EnterWithScope() final {
    for (TIRFrame child : frames) {
      child->EnterWithScope();
    }
  }
  /*!
   * \brief The method called when exiting RAII scope; exits the child frames
   *        in reverse order of entry.
   * \sa tvm::support::With
   */
  TVM_DLL void ExitWithScope() final {
    for (size_t remaining = frames.size(); remaining > 0; --remaining) {
      frames[remaining - 1]->ExitWithScope();
    }
  }
};

/*!
 * \brief Managed reference to WarpSpecializeFrameNode.
 *
 * Declared as a mutable, non-nullable object reference deriving from TIRFrame.
 *
 * \sa WarpSpecializeFrameNode
 */
class WarpSpecializeFrame : public TIRFrame {
public:
  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WarpSpecializeFrame,
                                                    TIRFrame,
                                                    WarpSpecializeFrameNode);
};

/*!
 * \brief Build a frame whose body is guarded by a thread-index range check.
 *
 * The guard is `min_bound <= thread_idx < min_bound + warp_group_size`, where
 * `min_bound = max(0, warp_group_idx * warp_group_size)`.
 *
 * \param warp_group_idx Index of the warp group selected by the guard.
 * \param thread_idx Thread index expression compared against the bounds; its
 *        dtype is used for the group index constant.
 * \param warp_group_size Number of threads per warp group.
 * \return A frame holding an If frame for the guard followed by a Then frame.
 */
WarpSpecializeFrame WarpSpecialize(int warp_group_idx, PrimExpr thread_idx,
                                   int warp_group_size = 128) {
  ObjectPtr<WarpSpecializeFrameNode> frame_node =
      make_object<WarpSpecializeFrameNode>();

  // First thread of the selected group, clamped so it is never negative.
  PrimExpr group_offset =
      IntImm(thread_idx.dtype(), warp_group_idx) * warp_group_size;
  PrimExpr min_bound = max(0, group_offset);
  PrimExpr max_bound = min_bound + warp_group_size;

  PrimExpr at_or_above_min = thread_idx >= min_bound;
  PrimExpr below_max = thread_idx < max_bound;
  PrimExpr condition = at_or_above_min && below_max;

  frame_node->frames.push_back(If(condition));
  frame_node->frames.push_back(Then());
  return WarpSpecializeFrame(frame_node);
}

// Register WarpSpecializeFrameNode as a TVM node type and register its
// constructor function as the global function "tl.WarpSpecialize".
TVM_REGISTER_NODE_TYPE(WarpSpecializeFrameNode);
TVM_REGISTER_GLOBAL("tl.WarpSpecialize").set_body_typed(WarpSpecialize);

273
274
} // namespace tl
} // namespace tvm