ir.cc 12.8 KB
Newer Older
1
2
3
4
5
6
/*!
 * \file tl/ir.cc
 * \brief Extension for the tvm script frontend.
 *
 */

7
#include "./transform/common/attr.h"
#include "op/builtin.h"
#include "tvm/ffi/any.h"
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/script/ir_builder/tir/ir.h>

#include <algorithm>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

16
17
18
19
20
namespace tvm {
namespace tl {

using namespace script::ir_builder::tir;

21
22
23
/*!
 * \brief Create a thread-index variable and record it on the enclosing
 *        PrimFunc frame.
 *
 * An IterVar of type kThreadIndex is built around a fresh Var; its range is
 * left undefined here. The IterVar is stored in the `env_threads` map of the
 * PrimFuncFrame found through the current IRBuilder, keyed by the Var.
 *
 * \param name The name of the created variable.
 * \param thread_tag The thread tag attached to the IterVar.
 * \param dtype The data type of the created variable.
 * \return The variable wrapped by the new IterVar.
 * \note Aborts via LOG(FATAL) when no PrimFuncFrame is found.
 */
static Var CreateEnvThread(String name, String thread_tag, DataType dtype) {
  using namespace tvm::tir;
  using namespace tvm::script::ir_builder;

  IterVar thread_iv(Range{nullptr}, Var(std::move(name), dtype),
                    tvm::tir::IterVarType::kThreadIndex,
                    std::move(thread_tag));
  Var thread_var = thread_iv->var;

  Optional<PrimFuncFrame> func_frame =
      IRBuilder::Current()->FindFrame<PrimFuncFrame>();
  if (!func_frame) {
    LOG(FATAL) << "EnvThread can only be used inside a PrimFunc";
  }
  func_frame.value()->env_threads.Set(thread_var, thread_iv);
  return thread_var;
}

36
/*!
 * \brief Create a ForFrame for a single serial loop over [0, dom).
 *
 * \param name The name of the loop variable.
 * \param dom The extent of the loop; the loop variable takes its dtype.
 * \return A ForFrame holding one variable and one range, whose loop builder
 *         emits a ForKind::kSerial loop.
 */
static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
  using namespace tvm::tir;
  const DataType dtype = dom.dtype();
  Var var = Var(name, dtype);

  // Create a frame that represents a loop over the given domain.
  ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
  n->vars.push_back(var);
  // Build the lower bound with the extent's dtype, as ParallelFor and
  // PersistentFor do, so that min and extent of the range agree on dtype.
  n->doms.push_back(Range(make_const(dtype, 0), dom));
  n->f_make_for_loop = [](const Array<Var> &vars, const Array<Range> &doms,
                          const Stmt &body) -> Stmt {
    ICHECK_EQ(vars.size(), 1);
    ICHECK_EQ(doms.size(), 1);
    return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body);
  };
  return ForFrame(n);
}

52
53
/*!
 * \brief Create a ForFrame describing a nest of parallel loops.
 *
 * One loop variable (named "v") and one range [0, extent) is created per
 * entry of \p extents. When the frame builds its statement, the body is
 * wrapped from the last extent outwards, so the first extent becomes the
 * outermost loop. Every generated loop is ForKind::kParallel and carries the
 * same \p annotations.
 *
 * \param extents The loop extents, outermost first.
 * \param annotations Annotations attached to every generated loop.
 * \return The frame describing the loop nest.
 */
ForFrame ParallelFor(const Array<PrimExpr> &extents,
                     const Map<String, ObjectRef> &annotations) {
  using namespace tvm::tir;
  ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
  n->vars.reserve(extents.size());
  n->doms.reserve(extents.size());
  for (const PrimExpr &extent : extents) {
    const DataType dtype = extent.dtype();
    n->vars.push_back(Var("v", dtype));
    n->doms.push_back(Range(make_const(dtype, 0), extent));
  }

  n->f_make_for_loop = [annotations](const Array<Var> &vars,
                                     const Array<Range> &doms,
                                     Stmt body) -> Stmt {
    ICHECK_EQ(vars.size(), doms.size());
    // Walk from the innermost loop to the outermost one.
    for (size_t depth = vars.size(); depth > 0; --depth) {
      const Range range = doms[depth - 1];
      const Var loop_var = vars[depth - 1];
      body = For(loop_var, range->min, range->extent, ForKind::kParallel, body,
                 /*thread_binding=*/std::nullopt, /*annotations=*/annotations);
    }
    return body;
  };
  return ForFrame(n);
}

79
80
81
82
83
/*!
 * \brief Create a ForFrame for a serial loop annotated with pipeline hints.
 *
 * The loop variable is named "v" and takes the dtype of \p stop; the loop
 * range is built from (\p start, \p stop). The generated loop receives the
 * following annotations, each only when its source is non-empty (or, for
 * \p num_stages, positive): "num_stages", "tl_pipeline_order",
 * "tl_pipeline_stage", "tl_pipeline_sync", "tl_pipeline_group".
 *
 * \param start The loop start.
 * \param stop The loop end.
 * \param num_stages Value stored under "num_stages" when greater than zero.
 * \param order Value stored under "tl_pipeline_order".
 * \param stages Value stored under "tl_pipeline_stage".
 * \param sync Value stored under "tl_pipeline_sync".
 * \param groups Value stored under "tl_pipeline_group".
 * \return The frame describing the annotated loop.
 */
ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages,
                      const Array<PrimExpr> &order,
                      const Array<PrimExpr> &stages,
                      const Array<Array<PrimExpr>> &sync,
                      const Array<Array<PrimExpr>> &groups) {
  using namespace tvm::tir;
  ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
  n->vars.push_back(Var("v", stop.dtype()));
  n->doms.push_back(Range(std::move(start), stop));

  n->f_make_for_loop = [num_stages, order, stages, sync,
                        groups](const Array<Var> &vars,
                                const Array<Range> &doms,
                                Stmt body) -> Stmt {
    ICHECK_EQ(vars.size(), doms.size());
    int n = vars.size();
    ICHECK(n == 1);

    // Collect only the hints that were actually provided.
    Map<String, ObjectRef> loop_annotations;
    if (num_stages > 0) {
      loop_annotations.Set("num_stages", PrimExpr(num_stages));
    }
    if (!order.empty()) {
      loop_annotations.Set("tl_pipeline_order", order);
    }
    if (!stages.empty()) {
      loop_annotations.Set("tl_pipeline_stage", stages);
    }
    if (!sync.empty()) {
      loop_annotations.Set("tl_pipeline_sync", sync);
    }
    if (!groups.empty()) {
      loop_annotations.Set("tl_pipeline_group", groups);
    }

    return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body,
               /*thread_binding=*/std::nullopt,
               /*annotations=*/loop_annotations);
  };
  return ForFrame(n);
}

112
113
/*!
 * \brief Create a ForFrame that walks a multi-dimensional domain in "waves".
 *
 * The product of \p domain is split into ceildiv(product, wave_size) waves.
 * A serial loop over the wave counter `w` is generated; inside it the linear
 * position `w * wave_size + index` is decomposed into one coordinate per
 * domain dimension, and each coordinate is bound to the corresponding frame
 * variable with a LetStmt wrapped around the loop.
 *
 * The decomposition uses a regrouped domain
 *   [domain.back() / group_size, domain[0], ..., domain[D-2], group_size]
 * where group_size is first clamped to min(group_size, domain.back()). The
 * last frame variable is reassembled as `group * group_size + offset`.
 *
 * When the analyzer can prove that there are at least two waves, the loop
 * body is prefixed with a guard that evaluates `tl::loop_break()` once the
 * linear position reaches the domain size.
 * NOTE(review): when that proof fails (e.g. symbolic sizes) no guard is
 * emitted -- confirm this is intended.
 *
 * \param domain The extents of the iteration domain; must be non-empty.
 * \param wave_size The number of positions covered by one wave.
 * \param index The offset of this worker inside a wave.
 * \param group_size The group size used for the last domain dimension.
 * \return The frame describing the persistent loop.
 */
ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size,
                       const PrimExpr &index, PrimExpr group_size) {
  using namespace tvm::tir;
  ICHECK(!domain.empty());
  const size_t ndim = domain.size();

  ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
  n->vars.reserve(ndim);
  n->doms.reserve(ndim);

  // Total number of positions in the domain.
  PrimExpr domain_size = domain[0];
  for (size_t i = 1; i < ndim; ++i) {
    domain_size *= domain[i];
  }

  auto waves = ceildiv(domain_size, wave_size);
  auto loop_var = Var("w", waves.dtype());
  group_size = min(group_size, domain[ndim - 1]);

  // One coordinate variable and one [0, extent) range per dimension.
  for (size_t i = 0; i < ndim; ++i) {
    const DataType dtype = domain[i].dtype();
    n->vars.push_back(Var("v" + std::to_string(i), dtype));
    n->doms.push_back(Range(make_const(dtype, 0), domain[i]));
  }

  // Regrouped domain: [last / group_size, dim 0 .. dim D-2, group_size].
  Array<PrimExpr> grouped_domain;
  grouped_domain.push_back(truncdiv(domain[ndim - 1], group_size));
  for (size_t i = 0; i + 1 < ndim; ++i) {
    grouped_domain.push_back(domain[i]);
  }
  grouped_domain.push_back(group_size);

  n->f_make_for_loop = [=](const Array<Var> &vars, const Array<Range> &doms,
                           const Stmt &body) -> Stmt {
    ICHECK_EQ(vars.size(), doms.size());
    // Guards both `vars.size() - 1` below and the indexing into `idxs`.
    ICHECK_EQ(vars.size() + 1, grouped_domain.size());

    // Linear position handled by this worker in the current wave.
    const PrimExpr linear_pos = loop_var * wave_size + index;

    // Peel coordinates off the linear position, fastest-varying last.
    Array<PrimExpr> idxs(grouped_domain.size(), PrimExpr());
    PrimExpr rem = linear_pos;
    for (int i = static_cast<int>(grouped_domain.size()) - 1; i >= 1; --i) {
      idxs.Set(i, truncmod(rem, grouped_domain[i]));
      rem = truncdiv(rem, grouped_domain[i]);
    }
    idxs.Set(0, rem);

    // Break out of the wave loop once the position leaves the domain.
    auto out_if = tvm::tir::IfThenElse(
        domain_size <= linear_pos,
        tvm::tir::Evaluate(
            tvm::tir::Call(DataType::Handle(), tvm::tl::loop_break(), {})),
        Stmt());

    arith::Analyzer analyzer;
    Stmt new_body = body;
    if (analyzer.CanProveGreaterEqual(waves, 2)) {
      new_body = SeqStmt({out_if, body});
    }

    Map<String, ObjectRef> anno;
    Stmt outer = For(loop_var, 0, waves, ForKind::kSerial, new_body,
                     /*thread_binding=*/std::nullopt, /*annotations=*/anno);

    // Bind every leading dimension to its peeled coordinate ...
    const size_t last = vars.size() - 1;
    for (size_t i = 0; i < last; ++i) {
      outer = tvm::tir::LetStmt(vars[i], idxs[i + 1], outer);
    }
    // ... and rebuild the last dimension from (group, offset-in-group).
    outer = tvm::tir::LetStmt(vars[last],
                              idxs[0] * group_size + idxs[vars.size()], outer);
    return outer;
  };

  return ForFrame(n);
}

181
182
183
184
185
186
187
188
189
/*!
 * \brief A frame that represents a kernel launch.
 *
 * \sa KernelLaunchFrameNode
 */
class KernelLaunchFrameNode : public TIRFrameNode {
public:
  Array<TIRFrame> frames;

190
191
192
193
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<KernelLaunchFrameNode>().def_ro(
        "frames", &KernelLaunchFrameNode::frames);
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
  }

  static constexpr const char *_type_key = "tl.KernelLaunchFrame";
  TVM_DECLARE_FINAL_OBJECT_INFO(KernelLaunchFrameNode, TIRFrameNode);

public:
  TVM_DLL void EnterWithScope() final {
    for (auto frame = frames.begin(); frame != frames.end(); ++frame)
      (*frame)->EnterWithScope();
  }
  /*!
   * \brief The method called when exiting RAII scope.
   * \sa tvm::support::With
   */
  TVM_DLL void ExitWithScope() final {
    for (auto frame = frames.rbegin(); frame != frames.rend(); ++frame)
      (*frame)->ExitWithScope();
  }
};

/*!
 * \brief Managed reference to KernelLaunchFrameNode.
 *
 * Declared as a mutable, non-nullable object reference deriving from
 * TIRFrame.
 *
 * \sa KernelLaunchFrameNode
 */
class KernelLaunchFrame : public TIRFrame {
public:
  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(KernelLaunchFrame, TIRFrame,
                                                    KernelLaunchFrameNode);
};

225
226
227
/*!
 * \brief Build the composite frame for a kernel launch.
 *
 * Two shapes are produced, selected by whether \p attrs contains the
 * `tilelang_is_cpu_kernel_frame` key:
 *  - CPU kernel: one serial loop frame per grid dimension, named
 *    "block_var_<i>". A block size is not allowed.
 *  - GPU kernel: one LaunchThread frame per grid dimension (bx/by/bz bound
 *    to blockIdx.x/y/z) followed by one per block dimension (tx/ty/tz bound
 *    to threadIdx.x/y/z). At most three dimensions each.
 *
 * In both cases a block frame named `MainBlockName` is appended last; when
 * \p attrs is defined it is stored as that block's annotations.
 *
 * \param grid_size The grid extents.
 * \param block_size_opt The optional block extents (GPU kernels only).
 * \param attrs Attributes attached to the main block.
 * \return The frame aggregating all nested frames in entry order.
 */
KernelLaunchFrame KernelLaunch(const Array<PrimExpr> &grid_size,
                               const Optional<Array<PrimExpr>> &block_size_opt,
                               const Map<String, ffi::Any> &attrs) {
  ObjectPtr<KernelLaunchFrameNode> n = make_object<KernelLaunchFrameNode>();

  // If the kernel is a CPU kernel, we don't need to launch any threads.
  const bool is_cpu_kernel_frame =
      attrs.defined() && attrs.count(tilelang_is_cpu_kernel_frame);

  const Array<PrimExpr> block_size =
      block_size_opt.value_or(Array<PrimExpr>());

  if (is_cpu_kernel_frame) {
    // Launch CPU Kernel: every grid dimension becomes a plain serial loop.
    ICHECK(block_size.empty()) << "CPU kernel cannot have block size";
    for (size_t i = 0; i < grid_size.size(); ++i) {
      n->frames.push_back(
          MakeIterVarFrame("block_var_" + std::to_string(i), grid_size[i]));
    }
  } else {
    // Launch GPU Kernel: bind each extent to the matching hardware index.
    ICHECK(grid_size.size() <= 3)
        << "GPU kernel supports at most 3 grid dimensions, got "
        << grid_size.size();
    ICHECK(block_size.size() <= 3)
        << "GPU kernel supports at most 3 block dimensions, got "
        << block_size.size();

    static constexpr const char *kAxisSuffix[] = {"x", "y", "z"};
    // Appends one LaunchThread frame per extent, e.g. ("b", "blockIdx.")
    // yields variables bx/by/bz tagged blockIdx.x/y/z.
    auto launch_axes = [&n](const std::string &var_prefix,
                            const std::string &tag_prefix,
                            const Array<PrimExpr> &extents) {
      for (size_t i = 0; i < extents.size(); ++i) {
        n->frames.push_back(LaunchThread(
            CreateEnvThread(var_prefix + kAxisSuffix[i],
                            tag_prefix + kAxisSuffix[i], extents[i].dtype()),
            extents[i]));
      }
    };
    launch_axes("b", "blockIdx.", grid_size);
    launch_axes("t", "threadIdx.", block_size);
  }

  // The main block always comes last; it carries the attributes, if any.
  auto main_block = tvm::script::ir_builder::tir::Block(MainBlockName);
  if (attrs.defined()) {
    main_block->annotations = attrs;
  }
  n->frames.push_back(main_block);

  return KernelLaunchFrame(n);
}

// Register the KernelLaunchFrameNode node type.
TVM_REGISTER_NODE_TYPE(KernelLaunchFrameNode);

294
295
296
297
298
299
300
301
// Register the loop and kernel-launch frame constructors as global functions
// under the names "tl.Parallel", "tl.Pipelined", "tl.Persistent" and
// "tl.KernelLaunch".
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("tl.Parallel", ParallelFor)
      .def("tl.Pipelined", PipelinedFor)
      .def("tl.Persistent", PersistentFor)
      .def("tl.KernelLaunch", KernelLaunch);
});
302

303
304
305
306
class WarpSpecializeFrameNode : public TIRFrameNode {
public:
  Array<TIRFrame> frames;

307
308
309
310
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<WarpSpecializeFrameNode>().def_ro(
        "frames", &WarpSpecializeFrameNode::frames);
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
  }

  static constexpr const char *_type_key = "tl.WarpSpecializeFrame";
  TVM_DECLARE_FINAL_OBJECT_INFO(WarpSpecializeFrameNode, TIRFrameNode);

public:
  TVM_DLL void EnterWithScope() final {
    for (auto frame = frames.begin(); frame != frames.end(); ++frame)
      (*frame)->EnterWithScope();
  }
  /*!
   * \brief The method called when exiting RAII scope.
   * \sa tvm::support::With
   */
  TVM_DLL void ExitWithScope() final {
    for (auto frame = frames.rbegin(); frame != frames.rend(); ++frame)
      (*frame)->ExitWithScope();
  }
};

/*!
 * \brief Managed reference to WarpSpecializeFrameNode.
 *
 * Declared as a mutable, non-nullable object reference deriving from
 * TIRFrame.
 *
 * \sa WarpSpecializeFrameNode
 */
class WarpSpecializeFrame : public TIRFrame {
public:
  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WarpSpecializeFrame,
                                                    TIRFrame,
                                                    WarpSpecializeFrameNode);
};

338
339
/*!
 * \brief Build a frame whose body only runs for the selected warp groups.
 *
 * Warp group `g` covers the thread indices
 * [g * warp_group_size, (g + 1) * warp_group_size). The requested ids are
 * sorted, de-duplicated and merged into maximal runs of consecutive ids;
 * each run contributes one `lo <= thread_idx < hi` predicate and the
 * predicates are OR-ed together. The resulting frame nests, in order: an If
 * frame on that condition, its Then frame, and an attribute frame
 * "warp_specialize" with value 1.
 *
 * \param warp_group_ids The ids of the warp groups to select; must be
 *        non-empty, since there would otherwise be no condition to test.
 * \param thread_idx The thread index expression to test.
 * \param warp_group_size The number of threads per warp group.
 * \return The frame aggregating the nested frames.
 */
WarpSpecializeFrame WarpSpecialize(const Array<IntImm> &warp_group_ids,
                                   const PrimExpr &thread_idx,
                                   int warp_group_size = 128) {
  ICHECK(!warp_group_ids.empty())
      << "WarpSpecialize requires at least one warp group id";
  ObjectPtr<WarpSpecializeFrameNode> n = make_object<WarpSpecializeFrameNode>();

  // Keep the ids at their native 64-bit width instead of narrowing to int.
  std::vector<int64_t> warp_groups;
  warp_groups.reserve(warp_group_ids.size());
  for (const IntImm &group_id : warp_group_ids) {
    warp_groups.push_back(group_id->value);
  }
  std::sort(warp_groups.begin(), warp_groups.end());
  // Repeated ids would otherwise emit the same range predicate twice.
  warp_groups.erase(std::unique(warp_groups.begin(), warp_groups.end()),
                    warp_groups.end());

  // Merge consecutive groups into half-open [first, last + 1) runs.
  std::vector<std::pair<int64_t, int64_t>> merged;
  for (int64_t group : warp_groups) {
    if (!merged.empty() && group == merged.back().second) {
      merged.back().second = group + 1;
    } else {
      merged.emplace_back(group, group + 1);
    }
  }

  PrimExpr condition;
  for (const auto &[start, end] : merged) {
    PrimExpr min_bound = IntImm(thread_idx.dtype(), start) * warp_group_size;
    PrimExpr max_bound = IntImm(thread_idx.dtype(), end) * warp_group_size;
    PrimExpr range_cond = (thread_idx >= min_bound) && (thread_idx < max_bound);

    if (condition.defined()) {
      condition = tir::Or(condition, range_cond);
    } else {
      condition = range_cond;
    }
  }

  IfFrame if_frame = If(condition);
  AttrFrame attr_frame = Attr(Integer(0), "warp_specialize", Integer(1));
  n->frames.push_back(if_frame);
  n->frames.push_back(Then());
  n->frames.push_back(attr_frame);
  return WarpSpecializeFrame(n);
}

// Register the WarpSpecializeFrameNode node type.
TVM_REGISTER_NODE_TYPE(WarpSpecializeFrameNode);
380
381
382
383
384
385
// Register WarpSpecialize as the global function "tl.WarpSpecialize" and run
// the reflection registration of both frame node classes defined in this
// file.
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.WarpSpecialize", WarpSpecialize);
  KernelLaunchFrameNode::RegisterReflection();
  WarpSpecializeFrameNode::RegisterReflection();
});
386

387
388
} // namespace tl
} // namespace tvm