ir.cc 14 KB
Newer Older
1
2
3
4
5
6
/*!
 * \file tl/ir.cc
 * \brief Extension for the tvm script frontend.
 *
 */

7
#include "./transform/common/attr.h"
8
#include "op/builtin.h"
9
#include "tvm/ffi/any.h"
10
11
12
#include <tvm/ffi/object.h>

#include "support/ffi_aliases.h"
13
#include <tvm/arith/analyzer.h>
14
#include <tvm/ffi/reflection/registry.h>
15
16
#include <tvm/script/ir_builder/tir/ir.h>

17
18
#include <algorithm>
#include <string>
#include <utility>
#include <vector>

19
20
21
22
23
namespace tvm {
namespace tl {

using namespace script::ir_builder::tir;

24
25
26
/*!
 * \brief Create a thread-index variable and record it as an environment
 *        thread of the PrimFunc currently being built.
 *
 * \param name Name hint of the new variable (e.g. "bx").
 * \param thread_tag Thread tag attached to the IterVar (e.g. "blockIdx.x").
 * \param dtype Data type of the new variable.
 * \return The variable of the created kThreadIndex IterVar.
 * \note Aborts via LOG(FATAL) when the current IRBuilder has no enclosing
 *       PrimFuncFrame.
 */
static Var CreateEnvThread(String name, String thread_tag, DataType dtype) {
  using namespace tvm::tir;
  using namespace tvm::script::ir_builder;

  // The IterVar is created without a domain (Range{nullptr}); it is only
  // registered in the enclosing function's env_threads map here.
  Var thread_var(std::move(name), dtype);
  IterVar thread_iv(Range{nullptr}, thread_var,
                    tvm::tir::IterVarType::kThreadIndex, std::move(thread_tag));

  if (Optional<PrimFuncFrame> func_frame =
          IRBuilder::Current()->FindFrame<PrimFuncFrame>()) {
    func_frame.value()->env_threads.Set(thread_var, thread_iv);
    return thread_var;
  }
  LOG(FATAL) << "EnvThread can only be used inside a PrimFunc";
  return thread_var;
}

39
/*!
 * \brief Build a ForFrame that emits one serial loop `for var in [0, dom)`.
 *
 * Used to turn each grid dimension of a CPU kernel into a plain loop.
 *
 * \param name Name hint of the loop variable.
 * \param dom Loop extent; its dtype is used for both the loop variable and
 *        the loop's lower bound.
 * \return A ForFrame whose finalizer builds a ForKind::kSerial For.
 */
static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
  using namespace tvm::tir;
  DataType dtype = dom->dtype;
  Var var(name, dtype);

  ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
  n->vars.push_back(var);
  // Use a zero of the extent's dtype (as ParallelFor/PersistentFor do) so the
  // range's min agrees with the loop variable instead of defaulting to int32.
  n->doms.push_back(Range(make_const(dtype, 0), dom));

  n->f_make_for_loop = [](const Array<Var> &vars, const Array<Range> &doms,
                          const Array<Optional<PrimExpr>> &steps,
                          Stmt body) -> Stmt {
    ICHECK_EQ(vars.size(), 1);
    ICHECK_EQ(doms.size(), 1);
    // An explicit step is optional; absent means the For default.
    Optional<PrimExpr> step =
        !steps.empty() ? steps[0] : Optional<PrimExpr>(std::nullopt);
    return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body,
               /*thread_binding=*/std::nullopt,
               /*annotations=*/tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any>{},
               /*step=*/step);
  };
  return ForFrame(n);
}

61
/*!
 * \brief Build the ForFrame behind `T.Parallel(*extents)`.
 *
 * One loop variable "v" is created per extent, ranging over [0, extent) in
 * the extent's dtype. On frame exit the loops are emitted as a perfect nest
 * of ForKind::kParallel loops (extents[0] outermost), each carrying
 * \p annotations.
 *
 * \param extents Loop extents, outermost first.
 * \param annotations Annotations attached to every generated loop.
 */
ForFrame ParallelFor(const Array<PrimExpr> &extents,
                     const Map<String, tvm::ffi::Any> &annotations) {
  using namespace tvm::tir;
  ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
  n->vars.reserve(extents.size());
  n->doms.reserve(extents.size());
  for (const auto &extent : extents) {
    DataType dtype = extent.dtype();
    n->vars.push_back(Var("v", dtype));
    n->doms.push_back(Range(make_const(dtype, 0), extent));
  }

  n->f_make_for_loop =
      [annotations](const Array<Var> &vars, const Array<Range> &doms,
                    const Array<Optional<PrimExpr>> &steps, Stmt body) -> Stmt {
    ICHECK_EQ(vars.size(), doms.size());
    const int64_t num_loops = static_cast<int64_t>(vars.size());
    const int64_t num_steps = static_cast<int64_t>(steps.size());
    // Wrap from the innermost loop outwards so vars[0] ends up outermost.
    for (int64_t i = num_loops - 1; i >= 0; --i) {
      Optional<PrimExpr> step =
          i < num_steps ? steps[i] : Optional<PrimExpr>(std::nullopt);
      body = For(vars[i], doms[i]->min, doms[i]->extent, ForKind::kParallel,
                 body, /*thread_binding=*/std::nullopt,
                 /*annotations=*/annotations, /*step=*/step);
    }
    return body;
  };
  return ForFrame(n);
}

91
92
93
94
95
/*!
 * \brief Build the ForFrame behind `T.Pipelined(start, stop, ...)`.
 *
 * Creates one loop variable "v" (dtype of \p stop) over [start, stop) and, on
 * frame exit, emits a single ForKind::kSerial loop annotated with the
 * software-pipelining hints. Only the hints that are set are attached:
 *   - "num_stages" when num_stages > 0,
 *   - "tl_pipeline_order" / "tl_pipeline_stage" / "tl_pipeline_sync" /
 *     "tl_pipeline_group" when the corresponding array is non-empty.
 */
ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages,
                      const Array<PrimExpr> &order,
                      const Array<PrimExpr> &stages,
                      const Array<Array<PrimExpr>> &sync,
                      const Array<Array<PrimExpr>> &groups) {
  using namespace tvm::tir;
  ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
  DataType dtype = stop.dtype();
  n->vars.push_back(Var("v", dtype));
  n->doms.push_back(Range(std::move(start), stop));

  // Capture explicitly (by value) so the frame node `n` itself is never
  // captured into its own finalizer.
  n->f_make_for_loop = [num_stages, order, stages, sync,
                        groups](const Array<Var> &vars, const Array<Range> &doms,
                                const Array<Optional<PrimExpr>> &steps,
                                Stmt body) -> Stmt {
    ICHECK_EQ(vars.size(), doms.size());
    ICHECK_EQ(vars.size(), 1) << "T.Pipelined expects exactly one loop variable";

    Map<String, tvm::ffi::Any> anno;
    if (num_stages > 0)
      anno.Set("num_stages", PrimExpr(num_stages));
    if (!order.empty())
      anno.Set("tl_pipeline_order", order);
    if (!stages.empty())
      anno.Set("tl_pipeline_stage", stages);
    if (!sync.empty())
      anno.Set("tl_pipeline_sync", sync);
    if (!groups.empty())
      anno.Set("tl_pipeline_group", groups);

    Optional<PrimExpr> step =
        !steps.empty() ? steps[0] : Optional<PrimExpr>(std::nullopt);
    return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body,
               /*thread_binding=*/std::nullopt, /*annotations=*/anno,
               /*step=*/step);
  };
  return ForFrame(n);
}

128
129
/*!
 * \brief Build the ForFrame behind `T.Persistent(domain, wave_size, index,
 *        group_size)`.
 *
 * The N-D iteration domain is linearized and walked in "waves": wave `w`
 * handles linear index `w * wave_size + index`. One coordinate variable
 * "v<i>" is created per domain dimension; on frame exit they are bound with
 * LetStmts to the decomposed linear index, and the body is wrapped in a
 * serial loop over `ceildiv(prod(domain), wave_size)` waves.
 *
 * The last dimension is split into groups of `min(group_size, domain[-1])`.
 * The linear index is decomposed in mixed radix, most significant digit
 * first: [domain[-1] / group_size, domain[0], ..., domain[-2], group_size].
 * This makes the last coordinate vary fastest within a group.
 *
 * NOTE(review): when group_size does not divide domain[-1], the truncdiv in
 * the leading radix drops the partial group. Linear indices near the end then
 * decode to a last coordinate >= domain[-1] (e.g. domain=[2,5], group_size=2,
 * linear 9 -> (0,5)) while some points are never visited. This is not
 * guarded here; confirm callers always pass a dividing group_size.
 */
ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size,
                       const PrimExpr &index, PrimExpr group_size) {
  using namespace tvm::tir;
  ICHECK(!domain.empty());
  ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
  const int64_t rank = static_cast<int64_t>(domain.size());
  n->vars.reserve(domain.size());
  n->doms.reserve(domain.size());

  // Total number of points in the iteration domain.
  PrimExpr domain_size = domain[0];
  for (int64_t i = 1; i < rank; ++i) {
    domain_size *= domain[i];
  }

  PrimExpr waves = ceildiv(domain_size, wave_size);
  Var loop_var("w", waves.dtype());
  const PrimExpr last_extent = domain[rank - 1];
  group_size = min(group_size, last_extent);

  for (int64_t i = 0; i < rank; ++i) {
    DataType dtype = domain[i].dtype();
    Var coord("v" + std::to_string(i), dtype);
    n->vars.push_back(coord);
    n->doms.push_back(Range(make_const(dtype, 0), domain[i]));
  }

  // Mixed-radix digits, most significant first (see function comment).
  Array<PrimExpr> grouped_domain;
  grouped_domain.push_back(truncdiv(last_extent, group_size));
  for (int64_t i = 0; i + 1 < rank; ++i) {
    grouped_domain.push_back(domain[i]);
  }
  grouped_domain.push_back(group_size);

  n->f_make_for_loop = [grouped_domain, loop_var, wave_size, index,
                        domain_size, waves,
                        group_size](const Array<Var> &vars,
                                    const Array<Range> &doms,
                                    const Array<Optional<PrimExpr>> &steps,
                                    Stmt body) -> Stmt {
    ICHECK_EQ(vars.size(), doms.size());
    Map<String, tvm::ffi::Any> anno;

    // Decompose the linear index of this wave into mixed-radix digits.
    const int64_t num_digits = static_cast<int64_t>(grouped_domain.size());
    Array<PrimExpr> idxs(grouped_domain.size(), PrimExpr());
    PrimExpr rem = loop_var * wave_size + index;
    for (int64_t i = num_digits - 1; i >= 1; --i) {
      idxs.Set(i, truncmod(rem, grouped_domain[i]));
      rem = truncdiv(rem, grouped_domain[i]);
    }
    idxs.Set(0, rem);

    // Leave the wave loop once the linear index runs past the domain.
    auto out_if = tvm::tir::IfThenElse(
        domain_size <= (loop_var * wave_size + index),
        tvm::tir::Evaluate(
            tvm::tir::Call(DataType::Handle(), tvm::tl::loop_break(), {})),
        Stmt());

    // The break check is only emitted when there provably are >= 2 waves.
    arith::Analyzer analyzer;
    Stmt new_body = body;
    if (analyzer.CanProveGreaterEqual(waves, 2)) {
      new_body = SeqStmt({out_if, body});
    }

    // NOTE(review): the first user-supplied step (if any) is applied to the
    // wave loop, not to a coordinate loop; confirm this is intended.
    Optional<PrimExpr> step =
        !steps.empty() ? steps[0] : Optional<PrimExpr>(std::nullopt);
    Stmt outer = For(loop_var, 0, waves, ForKind::kSerial, new_body,
                     /*thread_binding=*/std::nullopt, /*annotations=*/anno,
                     /*step=*/step);

    // Bind coordinates: v<i> (i < rank-1) takes digit i+1; the last
    // coordinate is rebuilt from its group digit and in-group offset.
    const int64_t num_vars = static_cast<int64_t>(vars.size());
    for (int64_t i = 0; i + 1 < num_vars; ++i) {
      outer = tvm::tir::LetStmt(vars[i], idxs[i + 1], outer);
    }
    outer = tvm::tir::LetStmt(vars[num_vars - 1],
                              idxs[0] * group_size + idxs[num_vars], outer);
    return outer;
  };

  return ForFrame(n);
}

201
202
203
204
205
206
207
208
209
/*!
 * \brief A frame that represents a kernel launch.
 *
 * \sa KernelLaunchFrameNode
 */
class KernelLaunchFrameNode : public TIRFrameNode {
public:
  Array<TIRFrame> frames;

210
211
212
213
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<KernelLaunchFrameNode>().def_ro(
        "frames", &KernelLaunchFrameNode::frames);
214
215
  }

216
217
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.KernelLaunchFrame",
                                    KernelLaunchFrameNode, TIRFrameNode);
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240

public:
  TVM_DLL void EnterWithScope() final {
    for (auto frame = frames.begin(); frame != frames.end(); ++frame)
      (*frame)->EnterWithScope();
  }
  /*!
   * \brief The method called when exiting RAII scope.
   * \sa tvm::support::With
   */
  TVM_DLL void ExitWithScope() final {
    for (auto frame = frames.rbegin(); frame != frames.rend(); ++frame)
      (*frame)->ExitWithScope();
  }
};

/*!
 * \brief Managed reference to KernelLaunchFrameNode.
 *
 * \sa KernelLaunchFrameNode
 */
class KernelLaunchFrame : public TIRFrame {
public:
241
242
243
244
245
246
247
  explicit KernelLaunchFrame(ObjectPtr<KernelLaunchFrameNode> data)
      : TIRFrame(::tvm::ffi::UnsafeInit{}) {
    ICHECK(data != nullptr);
    data_ = std::move(data);
  }
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(KernelLaunchFrame, TIRFrame,
                                                KernelLaunchFrameNode);
248
249
};

250
251
252
KernelLaunchFrame KernelLaunch(const Array<PrimExpr> &grid_size,
                               const Optional<Array<PrimExpr>> &block_size_opt,
                               const Map<String, ffi::Any> &attrs) {
253
254
  ObjectPtr<KernelLaunchFrameNode> n =
      tvm::ffi::make_object<KernelLaunchFrameNode>();
255
256
257
258
259

  // If the kernel is a CPU kernel, we don't need to launch any threads.
  bool is_cpu_kernel_frame =
      attrs.defined() && attrs.count(tilelang_is_cpu_kernel_frame);

260
261
  auto block_size = block_size_opt.value_or(Array<PrimExpr>());

262
  if (is_cpu_kernel_frame) {
263
    // Launch CPU Kernel
264
    ICHECK(grid_size.size() >= 0);
265
    ICHECK(block_size.empty()) << "CPU kernel cannot have block size";
266
267
268
269
270
271
    ICHECK(attrs.defined());
    // create grid loop var
    for (int i = 0; i < grid_size.size(); i++) {
      n->frames.push_back(
          MakeIterVarFrame("block_var_" + std::to_string(i), grid_size[i]));
    }
272
  } else {
273
274
    // Launch GPU Kernel
    ICHECK(grid_size.size() <= 3);
275
    if (!grid_size.empty())
276
277
278
      n->frames.push_back(LaunchThread(
          CreateEnvThread("bx", "blockIdx.x", grid_size[0].dtype()),
          grid_size[0]));
279
    if (grid_size.size() > 1)
280
281
282
      n->frames.push_back(LaunchThread(
          CreateEnvThread("by", "blockIdx.y", grid_size[1].dtype()),
          grid_size[1]));
283
    if (grid_size.size() > 2)
284
285
286
      n->frames.push_back(LaunchThread(
          CreateEnvThread("bz", "blockIdx.z", grid_size[2].dtype()),
          grid_size[2]));
287
288
    if (block_size.defined()) {
      ICHECK(block_size.size() <= 3);
289
      if (!block_size.empty()) {
290
291
292
293
294
295
296
297
298
299
300
301
302
303
        n->frames.push_back(LaunchThread(
            CreateEnvThread("tx", "threadIdx.x", block_size[0].dtype()),
            block_size[0]));
      }
      if (block_size.size() > 1) {
        n->frames.push_back(LaunchThread(
            CreateEnvThread("ty", "threadIdx.y", block_size[1].dtype()),
            block_size[1]));
      }
      if (block_size.size() > 2) {
        n->frames.push_back(LaunchThread(
            CreateEnvThread("tz", "threadIdx.z", block_size[2].dtype()),
            block_size[2]));
      }
304
    }
305
  }
306

307
  if (attrs.defined()) {
308
    auto empty_block = tvm::script::ir_builder::tir::Block(MainBlockName);
309
310
311
    empty_block->annotations = attrs;
    n->frames.push_back(empty_block);
  } else {
312
    n->frames.push_back(tvm::script::ir_builder::tir::Block(MainBlockName));
313
314
315
316
317
  }

  return KernelLaunchFrame(n);
}

318
TVM_FFI_STATIC_INIT_BLOCK() {
319
320
321
322
323
324
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("tl.Parallel", ParallelFor)
      .def("tl.Pipelined", PipelinedFor)
      .def("tl.Persistent", PersistentFor)
      .def("tl.KernelLaunch", KernelLaunch);
325
}
326

327
328
329
330
class WarpSpecializeFrameNode : public TIRFrameNode {
public:
  Array<TIRFrame> frames;

331
332
333
334
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<WarpSpecializeFrameNode>().def_ro(
        "frames", &WarpSpecializeFrameNode::frames);
335
336
  }

337
338
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.WarpSpecializeFrame",
                                    WarpSpecializeFrameNode, TIRFrameNode);
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356

public:
  TVM_DLL void EnterWithScope() final {
    for (auto frame = frames.begin(); frame != frames.end(); ++frame)
      (*frame)->EnterWithScope();
  }
  /*!
   * \brief The method called when exiting RAII scope.
   * \sa tvm::support::With
   */
  TVM_DLL void ExitWithScope() final {
    for (auto frame = frames.rbegin(); frame != frames.rend(); ++frame)
      (*frame)->ExitWithScope();
  }
};

class WarpSpecializeFrame : public TIRFrame {
public:
357
358
359
360
361
362
363
  explicit WarpSpecializeFrame(ObjectPtr<WarpSpecializeFrameNode> data)
      : TIRFrame(::tvm::ffi::UnsafeInit{}) {
    ICHECK(data != nullptr);
    data_ = std::move(data);
  }
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(WarpSpecializeFrame, TIRFrame,
                                                WarpSpecializeFrameNode);
364
365
};

366
367
/*!
 * \brief Open a warp-specialized region for the selected warp groups.
 *
 * Pushes `If(cond)`, its `Then()` branch and an
 * `attr warp_specialize = 1` frame. `cond` is true exactly when
 * \p thread_idx lies in one of the selected groups, where group g covers
 * [g * warp_group_size, (g + 1) * warp_group_size). Consecutive group ids are
 * merged into a single range test; duplicate ids are ignored.
 *
 * \param warp_group_ids Warp group indices that execute the region; must be
 *        non-empty.
 * \param thread_idx Thread index expression the condition is tested on.
 * \param warp_group_size Number of threads per warp group.
 */
WarpSpecializeFrame WarpSpecialize(const Array<IntImm> &warp_group_ids,
                                   const PrimExpr &thread_idx,
                                   int warp_group_size = 128) {
  // Without any group the condition below would stay undefined.
  ICHECK(!warp_group_ids.empty())
      << "WarpSpecialize requires at least one warp group id";

  ObjectPtr<WarpSpecializeFrameNode> n =
      tvm::ffi::make_object<WarpSpecializeFrameNode>();

  std::vector<int64_t> warp_groups;
  warp_groups.reserve(warp_group_ids.size());
  for (const IntImm &id : warp_group_ids) {
    warp_groups.push_back(id->value);
  }
  std::sort(warp_groups.begin(), warp_groups.end());

  // Collapse the sorted ids into half-open [start, end) runs of consecutive
  // groups.
  std::vector<std::pair<int64_t, int64_t>> merged;
  for (int64_t group : warp_groups) {
    if (!merged.empty() && group < merged.back().second) {
      continue; // duplicate id, already covered by the current run
    }
    if (!merged.empty() && group == merged.back().second) {
      merged.back().second = group + 1;
    } else {
      merged.emplace_back(group, group + 1);
    }
  }

  PrimExpr condition;
  for (const auto &[start, end] : merged) {
    PrimExpr min_bound = IntImm(thread_idx.dtype(), start) * warp_group_size;
    PrimExpr max_bound = IntImm(thread_idx.dtype(), end) * warp_group_size;
    PrimExpr range_cond = (thread_idx >= min_bound) && (thread_idx < max_bound);
    if (condition.defined()) {
      condition = tir::Or(condition, range_cond);
    } else {
      condition = range_cond;
    }
  }

  IfFrame if_frame = If(condition);
  AttrFrame attr_frame = Attr(Integer(0), "warp_specialize", Integer(1));
  n->frames.push_back(if_frame);
  n->frames.push_back(Then());
  n->frames.push_back(attr_frame);
  return WarpSpecializeFrame(n);
}

408
TVM_FFI_STATIC_INIT_BLOCK() {
409
410
411
412
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.WarpSpecialize", WarpSpecialize);
  KernelLaunchFrameNode::RegisterReflection();
  WarpSpecializeFrameNode::RegisterReflection();
413
}
414

415
416
} // namespace tl
} // namespace tvm