/*!
 * \file tl/op/elem.cc
 *
 * Define element-wise operators.
 */

#include "elem.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../target/utils.h"
#include "../transform/common/loop_fusion_utils.h"
#include "../transform/common/loop_parallel_transform_utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

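// Copy: element-wise copy of a source buffer region into a destination
// region. The call arguments are parsed as follows:
//   args[0], args[1]   : region-describing calls for src and dst (via RegionOp)
//   args[2] (optional) : coalesced_width hint, applied when > 0
//   args[3] (optional) : disable_tma flag
//   args[4] (optional) : eviction_policy (as an integer)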
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
  Array<Range> rgs[2];
  Buffer bf[2];
  for (int i = 0; i < 2; i++) {
    auto expr = args[i];
    auto call = expr.as<CallNode>();
    ICHECK(call);
    auto region = RegionOp(call->args, vmap);
    rgs[i] = region.GetRanges();
    bf[i] = region.GetBuffer();
  }
  std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]);
  std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]);
  if (args.size() >= 3) {
    auto coalesced_width = Downcast<IntImm>(args[2]);
    if (coalesced_width->value > 0) {
      this->coalesced_width = coalesced_width;
    }
  }
  if (args.size() >= 4) {
    auto disable_tma = Downcast<Bool>(args[3]);
    this->disable_tma = disable_tma;
  }
  if (args.size() >= 5) {
    this->eviction_policy = args[4].as<IntImmNode>()->value;
  }
}

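// Build one data-parallel IterVar ("i", "j", ...) per source dimension whose
// extent is not 1; size-1 dimensions are skipped and are later indexed
// directly by their range minimum in MakeIndices().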
Array<IterVar> Copy::MakeIterVars() const {
  Array<IterVar> loop_vars;
  size_t idx = 0;
  for (size_t i = 0; i < src_range.size(); i++) {
    if (is_one(src_range[i]->extent))
      continue;
    Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype);
    idx++;
    loop_vars.push_back(
        {Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
  }
  return loop_vars;
}

// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
Array<PrimExpr> Copy::MakeIndices(const Array<IterVar> &ivs,
                                  int src_dst) const {
  Array<PrimExpr> indices;
  Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
  size_t idx = 0;
  for (size_t i = 0; i < ranges.size(); i++) {
    if (is_one(ranges[i]->extent))
      indices.push_back(ranges[i]->min);
    else {
      indices.push_back(ranges[i]->min + ivs[idx]->var);
      idx++;
    }
  }
  ICHECK(idx == ivs.size())
      << "idx = " << idx << ", ivs.size() = " << ivs.size()
      << ", src name = " << src->name << ", dst name = " << dst->name;
  return indices;
}

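// Build the boundary predicate for the src (src_dst == 0) or dst
// (src_dst == 1) access: for each non-unit dimension the conditions
// `min + iv < extent` and `min + iv >= 0` are collected unless the analyzer
// can already prove them, and the surviving conditions are And-ed together.
// Returns an undefined PrimExpr when no predicate is needed.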
PrimExpr Copy::MakePredicate(arith::Analyzer *analyzer,
                             const Array<IterVar> &ivs, Array<PrimExpr> extents,
                             int src_dst) const {
  Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
  Array<PrimExpr> cond_list;
  ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
  size_t idx = 0;
  for (size_t i = 0; i < ranges.size(); i++) {
    if (is_one(ranges[i]->extent))
      continue;
    PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i];
    if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
      cond_list.push_back(cond);
    }
    cond = ranges[i]->min + ivs[idx]->var >= 0;
    if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
      cond_list.push_back(cond);
    }
    idx++;
  }
  if (cond_list.empty())
    return {};
  else {
    PrimExpr cond = cond_list[0];
    for (size_t i = 1; i < cond_list.size(); i++)
      cond = And(cond, cond_list[i]);
    return cond;
  }
}

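// Build the element-wise SIMT copy loop nest. For a 2-D copy this is roughly
//
//   for i in Parallel(0, extent_i):
//     for j in Parallel(0, extent_j):
//       dst[dmin0 + i, dmin1 + j] = src[smin0 + i, smin1 + j]
//
// with a cast inserted when the dtypes differ and if_then_else guards added
// when boundary predicates are required. A fully scalar copy degenerates to a
// single-iteration serial loop.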
For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const {
  Array<IterVar> loop_vars = MakeIterVars();
  bool is_scalar = loop_vars.size() == 0;
  if (is_scalar) {
    return For(Var("i"), 0, 1, ForKind::kSerial,
               BufferStore(dst, BufferLoad(src, {0}), {0}));
  }

  for (const auto &iv : loop_vars)
    analyzer->Bind(iv->var, iv->dom);

  ICHECK(loop_vars.size() <= src_range.size())
      << "loop_vars.size() = " << loop_vars.size()
      << ", src_range.size() = " << src_range.size() << ", src = " << src->name
      << ", dst = " << dst->name;

  ICHECK(loop_vars.size() <= dst_range.size())
      << "loop_vars.size() = " << loop_vars.size()
      << ", dst_range.size() = " << dst_range.size() << ", src = " << src->name
      << ", dst = " << dst->name;

  Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
  Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);

  PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
  PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);

  PrimExpr value = BufferLoad(src, src_indices);
  if (src->dtype != dst->dtype)
    value = Cast(dst->dtype, value);
  if (src_predicate.defined())
    value = if_then_else(src_predicate, value, make_zero(dst->dtype));

  Stmt body = BufferStore(dst, value, dst_indices);
  if (dst_predicate.defined())
    body = IfThenElse(dst_predicate, body);
  for (int i = loop_vars.size() - 1; i >= 0; i--) {
    Map<String, ObjectRef> annotations = {};
    if (coalesced_width.defined()) {
      annotations.Set("coalesced_width", coalesced_width);
    }
    body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
               ForKind::kParallel, body, std::nullopt, annotations);
  }
  return Downcast<For>(body);
}

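// Lower the copy to the best available implementation: first try an
// ldmatrix/stmatrix lowering, then (unless disable_tma is set) a TMA bulk
// copy, and otherwise fall back to the generic SIMT loop below.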
Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  Target target = T.target;
  bool is_cpu_target = target->GetTargetDeviceType() == kDLCPU;
  Stmt ldsm_stmt = LowerLDSMCopy(T, analyzer);
  if (ldsm_stmt.defined())
    return ldsm_stmt;

  if (!disable_tma) {
    Stmt bulk_copy_stmt = LowerBulkCopy(T, analyzer);
    if (bulk_copy_stmt.defined())
      return bulk_copy_stmt;
  }

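  // Generic path: build the SIMT loop, fuse and normalize the parallel nest,
  // then (on GPU) infer a per-thread layout, partition the loop over
  // T.thread_var and vectorize it; on CPU the loop is only vectorized.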
  auto simt_loop = MakeSIMTLoop(analyzer);
  auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));

  auto transformed_loop =
      Downcast<For>(ParallelLoopTransformer::Substitute(fused_loop));

  For vectorized_thread_loop;
  auto par_op = std::make_unique<ParallelOp>(transformed_loop);

  if (is_cpu_target) {
    vectorized_thread_loop = VectorizeLoop(transformed_loop);
  } else {
    std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
                                      InferLevel::kFree};
    for (auto level : levels) {
      par_op->InferLayout(
          {T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
    }
    auto loop_layout = par_op->GetLoopLayout();
    auto thread_var = T.thread_var;
    auto thread_loop =
        PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
    vectorized_thread_loop = VectorizeLoop(thread_loop);
  }

  if (par_op->GetPredicate(T.thread_var).defined()) {
    return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                      vectorized_thread_loop);
  }
  return vectorized_thread_loop;
}

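// Try to lower the copy to ldmatrix (shared.dyn -> local.fragment) or
// stmatrix (local.fragment -> shared.dyn). Returns an undefined Stmt when the
// preconditions do not hold, e.g. the target lacks the instruction, boundary
// predicates would be needed, the fragment does not follow the 8x8 layout,
// the element type is not 16-bit, or the shared access is not 16-byte
// contiguous.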
Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
  // Check buffer scope
  bool is_ldmatrix;
  if (TargetHasLdmatrix(T.target) && src.scope() == "shared.dyn" &&
      dst.scope() == "local.fragment") {
    is_ldmatrix = true;
  } else if (TargetHasStmatrix(T.target) && dst.scope() == "shared.dyn" &&
             src.scope() == "local.fragment") {
    is_ldmatrix = false;
  } else {
    return Stmt();
  }

  // Check no predicates
  Array<IterVar> loop_vars = MakeIterVars();
  if (loop_vars.size() < 2)
    return Stmt();
  for (const auto &iv : loop_vars)
    analyzer->Bind(iv->var, iv->dom);
  PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
  PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
  if (src_predicate.defined() || dst_predicate.defined())
    return Stmt();

  Buffer shared_tensor = is_ldmatrix ? src : dst;
  Buffer local_tensor = is_ldmatrix ? dst : src;

  Array<PrimExpr> local_indices = MakeIndices(loop_vars, is_ldmatrix ? 1 : 0);
  Fragment local_layout = Downcast<Fragment>(T.layout_map[local_tensor]);
  Array<PrimExpr> local_indices_transformed =
      local_layout->Forward(local_indices);
  local_tensor = T.buffer_remap[local_tensor];
  // Currently only the 1-D case is supported.
  if (local_layout->OutputDim() != 1)
    return Stmt();

  Array<PrimExpr> shared_indices = MakeIndices(loop_vars, is_ldmatrix ? 0 : 1);
  Array<PrimExpr> shared_indices_transformed = shared_indices;
  Layout shared_layout;
  if (T.buffer_remap.count(shared_tensor)) {
    shared_layout = T.layout_map[shared_tensor];
    shared_tensor = T.buffer_remap[shared_tensor];
    shared_indices_transformed = shared_layout->Forward(shared_indices);
  }

  // Check that local_layout follows the 8x8 fragment layout (normal or
  // transposed)
  bool is_transposed;
  IterVar col_var = loop_vars[loop_vars.size() - 1];
  IterVar row_var = loop_vars[loop_vars.size() - 2];
  PrimExpr local_layout_thread_map =
      FloorMod(local_layout->ForwardThread(local_indices, std::nullopt), 32);
  PrimExpr matrix_8x8_thread_map = makeGemmFragment8x8()->ForwardThread(
      {FloorMod(row_var, 8), FloorMod(col_var, 8)}, std::nullopt);
  PrimExpr matrix_8x8_thread_map_trans =
      makeGemmFragment8x8Transposed()->ForwardThread(
          {FloorMod(row_var, 8), FloorMod(col_var, 8)}, std::nullopt);
  PrimExpr local_indices_flattened =
      local_tensor.OffsetOf(local_indices_transformed).back();
  if (analyzer->CanProveEqual(matrix_8x8_thread_map, local_layout_thread_map) &&
      IndiceCanVectorize(local_indices_flattened, col_var->var,
                         col_var->dom->extent, 2, analyzer)) {
    is_transposed = false;
  } else if (analyzer->CanProveEqual(matrix_8x8_thread_map_trans,
                                     local_layout_thread_map) &&
             IndiceCanVectorize(local_indices_flattened, row_var->var,
                                row_var->dom->extent, 2, analyzer)) {
    is_transposed = true;
  } else {
    return Stmt();
  }
  // Check that the shared layout is 16-byte contiguous (8 consecutive 2-byte
  // elements)
  if (shared_tensor->dtype.bytes() != 2)
    return Stmt();
  PrimExpr flattened_indice =
      shared_tensor.OffsetOf(shared_indices_transformed).back();
  if (!IndiceCanVectorize(flattened_indice, loop_vars.back()->var,
                          loop_vars.back()->dom->extent, 8, analyzer))
    return Stmt();

  // Can only support the case where dst_range covers the full destination
  // buffer
  for (size_t i = 0; i < dst_range.size(); i++) {
    if (!is_zero(dst_range[i]->min) ||
        !analyzer->CanProveEqual(dst_range[i]->extent, dst->shape[i]))
      return Stmt();
  }

  // Do the lowering here, try vectorized ldmatrix/stmatrix by 4/2/1
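  // `num` selects the .x4/.x2/.x1 instruction variant, i.e. how many 8x8
  // 16-bit tiles a single ldmatrix/stmatrix moves. Each thread holds two
  // 16-bit elements per tile, so a per-thread extent divisible by 8 permits
  // the x4 form and divisibility by 4 permits the x2 form.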
  PrimExpr extent = local_tensor->shape[0];
  int num = 1;
  if (analyzer->CanProveEqual(FloorMod(extent, 8), 0))
    num = 4;
  else if (analyzer->CanProveEqual(FloorMod(extent, 4), 0))
    num = 2;

  Array<PrimExpr> args;
  const Op &op = is_ldmatrix ? tl::ptx_ldmatirx() : tl::ptx_stmatirx();
  args.push_back(static_cast<int>(is_transposed));
  args.push_back(num);

  // Create the shared-memory address corresponding to the local address.
  // If not transposed:
  //   coords = Inverse(base + 2 * ((thread / 8) % num), warp + (thread % 8) * 4)
  // If transposed:
  //   coords = Inverse(base + 2 * ((thread / 8) % num) + thread % 2,
  //                    warp + (thread % 8) / 2)
  Var local_iter("i");
  Layout inv = local_layout->Inverse();
  Array<PrimExpr> shared_coords;
  PrimExpr warp = FloorDiv(T.thread_var, 32) * 32;
  if (!is_transposed)
    shared_coords = inv->Forward(
        {local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num),
         warp + FloorMod(T.thread_var, 8) * 4});
  else
    shared_coords = inv->Forward(
        {local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num) +
             FloorMod(T.thread_var, 2),
         warp + FloorDiv(FloorMod(T.thread_var, 8), 2)});
  shared_coords.pop_back(); // remove the replication index
  if (shared_layout.defined())
    shared_coords = shared_layout->Forward(shared_coords);
  PrimExpr shared_addr = shared_tensor.access_ptr(
      is_ldmatrix ? 1 : 2, DataType::Handle(), 1,
      shared_tensor.OffsetOf(shared_coords).back(), PrimExpr(2 * num));
  args.push_back(shared_addr);

  if (is_ldmatrix) {
    // Can only support the same dtype for ldmatrix
    if (local_tensor->dtype != shared_tensor->dtype)
      return Stmt();
    PrimExpr local_addr = local_tensor.access_ptr(
        2, DataType::Handle(), 1, local_iter * 2 * num, PrimExpr(2 * num));
    args.push_back(local_addr);
  } else {
    for (int i = 0; i < num; i++) {
      PrimExpr value0 =
          BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i});
      PrimExpr value1 =
          BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i + 1});
      if (local_tensor->dtype != shared_tensor->dtype) {
        value0 = Cast(shared_tensor->dtype, value0);
        value1 = Cast(shared_tensor->dtype, value1);
      }
      PrimExpr value_packed =
          Call(DataType::Int(32), pack_b16(), {value0, value1});
      args.push_back(value_packed);
    }
  }

  auto body = Evaluate(Call(DataType::Handle(), op, args));
  For for_node =
      For(local_iter, 0, FloorDiv(extent, 2 * num), ForKind::kSerial, body);
  for_node = LoopPragmaUnroll(for_node);
  auto range = T.thread_bounds;
  if (range.defined()) {
    auto thread_var = T.thread_var;
    auto thread_var_with_offset = thread_var - range->min;
    for_node.CopyOnWrite()->body =
        Substitute(for_node->body, {{thread_var, thread_var_with_offset}});
  }
  return for_node;
}

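// Layout inference is delegated to a ParallelOp built from the SIMT loop, so
// the copy is laid out exactly like the equivalent parallel loop nest.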
LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  // Use parallel op to infer the layout
  if (par_op_ == nullptr) {
    arith::Analyzer analyzer;
    par_op_ = std::make_unique<ParallelOp>(MakeSIMTLoop(&analyzer));
  }
  return par_op_->InferLayout(T, level);
}

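// Fill: store a single value into every element of a destination region.
//   args[0] : either a BufferLoad describing the region to fill (Ramp indices
//             become ranges) or an access pointer covering the whole buffer
//   args[1] : the fill value, cast to the destination dtype when necessary
// Static region bounds are checked against the destination shape.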
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {

  if (args[0]->IsInstance<BufferLoadNode>()) {
    auto buffer_load = Downcast<BufferLoad>(args[0]);
    for (const auto &index : buffer_load->indices) {
      if (const auto *ramp = index.as<RampNode>()) {
        CHECK(ramp->stride.as<IntImmNode>()->value == 1)
            << "Only stride 1 ramps are supported";
        const auto *lanes = ramp->lanes.as<IntImmNode>();
        CHECK(lanes)
            << "Scalable vectors not supported in BufferRegion conversion";
        region.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
      } else {
        region.push_back(Range::FromMinExtent(index, 1));
      }
    }
    dst = buffer_load->buffer;
  } else {
    dst = vmap[GetVarFromAccessPtr(args[0])];
    for (int i = 0; i < dst->shape.size(); i++) {
      region.push_back(Range(0, dst->shape[i]));
    }
  }

  if (args[1]->dtype != dst->dtype) {
    value = Cast(dst->dtype, args[1]);
  } else {
    value = args[1];
  }

  ICHECK(region.size() == dst->shape.size())
      << "region size = " << region.size() << " != " << dst->shape.size();
  for (int i = 0; i < region.size(); i++) {
    // Bounds check when the region bounds are static
    if (region[i]->min.as<IntImm>()) {
      int64_t min = Downcast<IntImm>(region[i]->min)->value;
      ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0";
    }
    if (region[i]->extent.as<IntImm>()) {
      int64_t extent = Downcast<IntImm>(region[i]->extent)->value;
      ICHECK_LE(extent, Downcast<IntImm>(dst->shape[i])->value)
          << "region[" << i << "] = " << extent << " > " << dst->shape[i];
    }
  }
}

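// Build the parallel loop nest that stores `value` into every element of the
// fill region, one loop per destination dimension.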
For Fill::MakeSIMTLoop(arith::Analyzer *analyzer) const {
  int ndim = dst->shape.size();
  Array<IterVar> loop_vars;
  Array<PrimExpr> dst_indices;
  for (int i = 0; i < ndim; i++) {
    Var var = Var(std::string{char('i' + i)}, region[i]->extent->dtype);
    loop_vars.push_back({region[i], var, IterVarType::kDataPar});
    dst_indices.push_back(var);
  }
  Stmt body = BufferStore(dst, value, dst_indices);
  for (int i = ndim - 1; i >= 0; i--) {
    body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
               ForKind::kParallel, body);
  }
  return Downcast<For>(body);
}

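// Lowering depends on the destination scope: fragments get layout inference
// plus thread partitioning, plain "local" buffers are simply vectorized,
// shared memory is partitioned across the thread block, and any other scope
// is rejected.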
Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {

  if (dst.scope() == "local.fragment") {
    auto par_op = std::make_unique<ParallelOp>(MakeSIMTLoop(analyzer));
    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
                        InferLevel::kFree);
    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
                        InferLevel::kFree);
    auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                     par_op->GetLoopLayout());
    auto vectorized_thread_loop = VectorizeLoop(thread_loop);
    if (par_op->GetPredicate(T.thread_var).defined()) {
      return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                        vectorized_thread_loop);
    }
    return vectorized_thread_loop;
  } else if (dst.scope() == "local") {
    auto init_loop = MakeSIMTLoop(analyzer);
    auto vectorized_thread_loop = VectorizeLoop(init_loop);
    return vectorized_thread_loop;
  } else if (dst.scope() == "shared.dyn" || dst.scope() == "shared") {
    auto par_op = std::make_unique<ParallelOp>(MakeSIMTLoop(analyzer));
    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
                        InferLevel::kFree);
    auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                     par_op->GetLoopLayout());
    auto vectorized_thread_loop = VectorizeLoop(thread_loop);
    if (par_op->GetPredicate(T.thread_var).defined()) {
      return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                        vectorized_thread_loop);
    }
    return vectorized_thread_loop;
  } else {
    LOG(FATAL) << "Unsupported scope " << dst.scope();
  }
}

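// Register the TL-level `copy` and `fill` operators with TIR; the macro below
// defines the op and its call-effect attributes. In the Python frontend these
// are typically surfaced as T.copy(...) and T.fill(...), though the exact
// frontend spelling is an assumption and not fixed here.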
TIR_REGISTER_TL_OP(Copy, copy)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_REGISTER_TL_OP(Fill, fill)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm