/*!
 * \file tl/op/elem.cc
 *
 * Define element-wise operators.
 */

#include "elem.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../target/utils.h"
#include "../transform/common/loop_fusion_utils.h"
#include "../transform/common/loop_parallel_transform_utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

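// Construct a Copy op from its call arguments: args[0] and args[1] are the
// source and destination access regions, and an optional args[2] carries the
// requested coalesced width for the generated copy loop.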
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
  Array<Range> rgs[2];
  Buffer bf[2];
  for (int i = 0; i < 2; i++) {
    auto expr = args[i];
    auto call = expr.as<CallNode>();
    ICHECK(call);
    auto region = RegionOp(call->args, vmap);
    rgs[i] = region.GetRanges();
    bf[i] = region.GetBuffer();
  }
  std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]);
  std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]);
  if (args.size() >= 3) {
    coalesced_width = Downcast<IntImm>(args[2]);
  }
}

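// Create one data-parallel IterVar per non-unit dimension of the source
// region; unit-extent dimensions need no loop and are skipped here.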
Array<IterVar> Copy::MakeIterVars() const {
  Array<IterVar> loop_vars;
  size_t idx = 0;
  for (size_t i = 0; i < src_range.size(); i++) {
    if (is_one(src_range[i]->extent))
      continue;
    Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype);
    idx++;
    loop_vars.push_back(
        {Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
  }
  return loop_vars;
}

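// Map the loop itervars back to buffer indices by offsetting them with each
// range's min; unit-extent dimensions contribute their constant min directly.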
// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
Array<PrimExpr> Copy::MakeIndices(const Array<IterVar> &ivs,
                                  int src_dst) const {
  Array<PrimExpr> indices;
  Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
  size_t idx = 0;
  for (size_t i = 0; i < ranges.size(); i++) {
    if (is_one(ranges[i]->extent))
      indices.push_back(ranges[i]->min);
    else {
      indices.push_back(ranges[i]->min + ivs[idx]->var);
      idx++;
    }
  }
  ICHECK(idx == ivs.size())
      << "idx = " << idx << ", ivs.size() = " << ivs.size()
      << "src name = " << src->name << ", dst name = " << dst->name;
  return indices;
}

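// Build the boundary predicate for one side of the copy: for every non-unit
// dimension, collect the `index < extent` and `index >= 0` conditions that the
// analyzer cannot already prove, and AND them together. Returns an undefined
// expression when no guard is needed.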
PrimExpr Copy::MakePredicate(arith::Analyzer *analyzer,
                             const Array<IterVar> &ivs, Array<PrimExpr> extents,
                             int src_dst) const {
  Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
  Array<PrimExpr> cond_list;
  ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
  size_t idx = 0;
  for (size_t i = 0; i < ranges.size(); i++) {
    if (is_one(ranges[i]->extent))
      continue;
    PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i];
    if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
      cond_list.push_back(cond);
    }
    cond = ranges[i]->min + ivs[idx]->var >= 0;
    if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
      cond_list.push_back(cond);
    }
    idx++;
  }
  if (cond_list.empty())
    return {};
  else {
    PrimExpr cond = cond_list[0];
    for (size_t i = 1; i < cond_list.size(); i++)
      cond = And(cond, cond_list[i]);
    return cond;
  }
}

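// Lower the copy into a nest of data-parallel loops: load from src, cast if
// the dtypes differ, guard with the boundary predicates, and store to dst.
// A fully scalar copy degenerates to a single unit-extent serial loop.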
For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const {
  Array<IterVar> loop_vars = MakeIterVars();
  bool is_scalar = loop_vars.size() == 0;
  if (is_scalar) {
    return For(Var("i"), 0, 1, ForKind::kSerial,
               BufferStore(dst, BufferLoad(src, {0}), {0}));
  }

  for (const auto &iv : loop_vars)
    analyzer->Bind(iv->var, iv->dom);

  ICHECK(loop_vars.size() <= src_range.size())
      << "loop_vars.size() = " << loop_vars.size()
      << ", src_range.size() = " << src_range.size() << ", src = " << src->name
      << ", dst = " << dst->name;

  ICHECK(loop_vars.size() <= dst_range.size())
      << "loop_vars.size() = " << loop_vars.size()
      << ", dst_range.size() = " << dst_range.size() << ", src = " << src->name
      << ", dst = " << dst->name;

  Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
  Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);

  PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
  PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);

  PrimExpr value = BufferLoad(src, src_indices);
  if (src->dtype != dst->dtype)
    value = Cast(dst->dtype, value);
  if (src_predicate.defined())
    value = if_then_else(src_predicate, value, make_zero(dst->dtype));

  Stmt body = BufferStore(dst, value, dst_indices);
  if (dst_predicate.defined())
    body = IfThenElse(dst_predicate, body);
  for (int i = loop_vars.size() - 1; i >= 0; i--) {
    Map<String, ObjectRef> annotations = {};
    if (coalesced_width.defined()) {
      annotations.Set("coalesced_width", coalesced_width);
    }
    body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
               ForKind::kParallel, body, NullOpt, annotations);
  }
  return Downcast<For>(body);
}

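// Main lowering entry for Copy. Specialized paths are tried first
// (ldmatrix/stmatrix, then the bulk-copy path); otherwise the generic SIMT
// loop is fused, partitioned over threads (for non-CPU targets) and
// vectorized, with the thread predicate wrapped around the result if needed.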
Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  Target target = T.target;
  bool is_cpu_target = target->GetTargetDeviceType() == kDLCPU;
  Stmt ldsm_stmt = LowerLDSMCopy(T, analyzer);
  if (ldsm_stmt.defined())
    return ldsm_stmt;

  Stmt bulk_copy_stmt = LowerBulkCopy(T, analyzer);
  if (bulk_copy_stmt.defined())
    return bulk_copy_stmt;
  auto simt_loop = MakeSIMTLoop(analyzer);
  auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));

  auto transformed_loop =
      Downcast<For>(ParallelLoopTransformer::Substitute(fused_loop));

  For vectorized_thread_loop;
  auto par_op = std::make_unique<ParallelOp>(transformed_loop);

  if (is_cpu_target) {
    vectorized_thread_loop = VectorizeLoop(transformed_loop);
  } else {
    std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
                                      InferLevel::kFree};
    for (auto level : levels) {
      par_op->InferLayout(
          {T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
    }
    auto loop_layout = par_op->GetLoopLayout();
    auto thread_var = T.thread_var;
    auto thread_loop =
        PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
    vectorized_thread_loop = VectorizeLoop(thread_loop);
  }

  if (par_op->GetPredicate(T.thread_var).defined()) {
    return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                      vectorized_thread_loop);
  }

  return vectorized_thread_loop;
}

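// Try to lower the copy to ldmatrix (shared -> fragment) or stmatrix
// (fragment -> shared) intrinsics. Returns an undefined Stmt whenever the
// target, buffer scopes, predicates, layouts, or alignment rule this path
// out, so the caller falls back to other lowerings.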
Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
  // Check buffer scope
  bool is_ldmatrix;
  if (TargetHasLdmatrix(T.target) && src.scope() == "shared.dyn" &&
      dst.scope() == "local.fragment") {
    is_ldmatrix = true;
  } else if (TargetHasStmatrix(T.target) && dst.scope() == "shared.dyn" &&
             src.scope() == "local.fragment") {
    is_ldmatrix = false;
  } else {
    return Stmt();
  }

  // Check no predicates
  Array<IterVar> loop_vars = MakeIterVars();
  if (loop_vars.size() < 2)
    return Stmt();
  for (const auto &iv : loop_vars)
    analyzer->Bind(iv->var, iv->dom);
  PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
  PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
  if (src_predicate.defined() || dst_predicate.defined())
    return Stmt();

  Buffer shared_tensor = is_ldmatrix ? src : dst;
  Buffer local_tensor = is_ldmatrix ? dst : src;

  Array<PrimExpr> local_indices = MakeIndices(loop_vars, is_ldmatrix ? 1 : 0);
  Fragment local_layout = Downcast<Fragment>(T.layout_map[local_tensor]);
  Array<PrimExpr> local_indices_transformed =
      local_layout->Forward(local_indices);
  local_tensor = T.buffer_remap[local_tensor];
  // Currently only a 1-D (flattened) local layout is supported
  if (local_layout->OutputDim() != 1)
    return Stmt();

  Array<PrimExpr> shared_indices = MakeIndices(loop_vars, is_ldmatrix ? 0 : 1);
  Array<PrimExpr> shared_indices_transformed = shared_indices;
  Layout shared_layout;
  if (T.buffer_remap.count(shared_tensor)) {
    shared_layout = T.layout_map[shared_tensor];
    shared_tensor = T.buffer_remap[shared_tensor];
    shared_indices_transformed = shared_layout->Forward(shared_indices);
  }

  // Check local_layout follows 8x8 layout
  bool is_transposed;
  IterVar col_var = loop_vars[loop_vars.size() - 1];
  IterVar row_var = loop_vars[loop_vars.size() - 2];
  PrimExpr local_layout_thread_map =
      FloorMod(local_layout->ForwardThread(local_indices, NullOpt), 32);
  PrimExpr matrix_8x8_thread_map = makeGemmFragment8x8()->ForwardThread(
      {FloorMod(row_var, 8), FloorMod(col_var, 8)}, NullOpt);
  PrimExpr matrix_8x8_thread_map_trans =
      makeGemmFragment8x8Transposed()->ForwardThread(
          {FloorMod(row_var, 8), FloorMod(col_var, 8)}, NullOpt);
  PrimExpr local_indices_flattened =
      local_tensor.OffsetOf(local_indices_transformed).back();
  if (analyzer->CanProveEqual(matrix_8x8_thread_map, local_layout_thread_map) &&
      IndiceCanVectorize(local_indices_flattened, col_var->var,
                         col_var->dom->extent, 2, analyzer)) {
    is_transposed = false;
  } else if (analyzer->CanProveEqual(matrix_8x8_thread_map_trans,
                                     local_layout_thread_map) &&
             IndiceCanVectorize(local_indices_flattened, row_var->var,
                                row_var->dom->extent, 2, analyzer)) {
    is_transposed = true;
  } else {
    return Stmt();
  }
  // Check that shared_layout is contiguous in 16-byte chunks
  if (shared_tensor->dtype.bytes() != 2)
    return Stmt();
  PrimExpr flattened_indice =
      shared_tensor.OffsetOf(shared_indices_transformed).back();
  if (!IndiceCanVectorize(flattened_indice, loop_vars.back()->var,
                          loop_vars.back()->dom->extent, 8, analyzer))
    return Stmt();

  // Only support the case where local_range is a full range
  for (size_t i = 0; i < dst_range.size(); i++) {
    if (!is_zero(dst_range[i]->min) ||
        !analyzer->CanProveEqual(dst_range[i]->extent, dst->shape[i]))
      return Stmt();
  }

  // Do the lowering here, try vectorized ldmatrix/stmatrix by 4/2/1
  PrimExpr extent = local_tensor->shape[0];
  int num = 1;
  if (analyzer->CanProveEqual(FloorMod(extent, 8), 0))
    num = 4;
  else if (analyzer->CanProveEqual(FloorMod(extent, 4), 0))
    num = 2;

  Array<PrimExpr> args;
  const Op &op = is_ldmatrix ? tl::ptx_ldmatirx() : tl::ptx_stmatirx();
  args.push_back(static_cast<int>(is_transposed));
  args.push_back(num);

  // Create the shared-memory address with respect to the local address.
  // If not transposed:
  //   coords = Inverse(base + 2 * (thread / 8) % num, warp + (thread % 8) * 4)
  // If transposed:
  //   coords = Inverse(base + 2 * (thread / 8) % num + thread % 2,
  //                    warp + thread % 8 / 2)
  Var local_iter("i");
  Layout inv = local_layout->Inverse();
  Array<PrimExpr> shared_coords;
  PrimExpr warp = FloorDiv(T.thread_var, 32) * 32;
  if (!is_transposed)
    shared_coords = inv->Forward(
        {local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num),
         warp + FloorMod(T.thread_var, 8) * 4});
  else
    shared_coords = inv->Forward(
        {local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num) +
             FloorMod(T.thread_var, 2),
         warp + FloorDiv(FloorMod(T.thread_var, 8), 2)});
  shared_coords.pop_back(); // remove rep
  if (shared_layout.defined())
    shared_coords = shared_layout->Forward(shared_coords);
  PrimExpr shared_addr = shared_tensor.access_ptr(
      is_ldmatrix ? 1 : 2, DataType::Handle(), 1,
      shared_tensor.OffsetOf(shared_coords).back(), PrimExpr(2 * num));
  args.push_back(shared_addr);

  if (is_ldmatrix) {
    // ldmatrix only supports src and dst with the same dtype
    if (local_tensor->dtype != shared_tensor->dtype)
      return Stmt();
    PrimExpr local_addr = local_tensor.access_ptr(
        2, DataType::Handle(), 1, local_iter * 2 * num, PrimExpr(2 * num));
    args.push_back(local_addr);
  } else {
    for (int i = 0; i < num; i++) {
      PrimExpr value0 =
          BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i});
      PrimExpr value1 =
          BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i + 1});
      if (local_tensor->dtype != shared_tensor->dtype) {
        value0 = Cast(shared_tensor->dtype, value0);
        value1 = Cast(shared_tensor->dtype, value1);
      }
      PrimExpr value_packed =
          Call(DataType::Int(32), pack_b16(), {value0, value1});
      args.push_back(value_packed);
    }
  }

  auto body = Evaluate(Call(DataType::Handle(), op, args));
  For for_node =
      For(local_iter, 0, FloorDiv(extent, 2 * num), ForKind::kSerial, body);
  for_node = LoopPragmaUnroll(for_node);
  return for_node;
}

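// Infer layouts by delegating to a ParallelOp built from the SIMT loop.
// When both sides already carry fragment layouts they must match, since a
// fragment-to-fragment copy cannot change layout without going through
// shared memory.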
LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  // Use parallel op to infer the layout
  if (par_op_ == nullptr) {
    arith::Analyzer analyzer;
    par_op_ = std::make_unique<ParallelOp>(MakeSIMTLoop(&analyzer));
  }
  if (T.layout_map.count(src) && T.layout_map.count(dst)) {
    // Only compare fragment layout
    if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
      const FragmentNode *src_layout = T.layout_map[src].as<Fragment>().get();
      const FragmentNode *dst_layout = T.layout_map[dst].as<Fragment>().get();
      if (src_layout && dst_layout) {
        ICHECK(src_layout->IsEqual(dst_layout, true))
            << "Get different layout for " << src << " and " << dst
            << "\nLHS = " << src_layout->DebugOutput()
            << "\nRHS = " << dst_layout->DebugOutput()
            << "\nYou may need to use a shared memory to transform the layout";
      }
    }
  }
  return par_op_->InferLayout(T, level);
}

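// Construct a Fill op: args[0] names the destination, either as a BufferLoad
// describing the region to fill or as an access pointer covering the whole
// buffer, and args[1] is the fill value, cast to the destination dtype when
// necessary. Statically known region bounds are checked against the shape.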
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {

  if (args[0]->IsInstance<BufferLoadNode>()) {
    auto buffer_load = Downcast<BufferLoad>(args[0]);
    for (const auto &index : buffer_load->indices) {
      if (const auto *ramp = index.as<RampNode>()) {
        CHECK(ramp->stride.as<IntImmNode>()->value == 1)
            << "Only stride 1 ramps are supported";
        const auto *lanes = ramp->lanes.as<IntImmNode>();
        CHECK(lanes)
            << "Scalable vectors not supported in BufferRegion conversion";
        region.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
      } else {
        region.push_back(Range::FromMinExtent(index, 1));
      }
    }
    dst = buffer_load->buffer;
  } else {
    dst = vmap[GetVarFromAccessPtr(args[0])];
    for (int i = 0; i < dst->shape.size(); i++) {
      region.push_back(Range(0, dst->shape[i]));
    }
  }

  if (args[1]->dtype != dst->dtype) {
    value = Cast(dst->dtype, args[1]);
  } else {
    value = args[1];
  }

  ICHECK(region.size() == dst->shape.size())
      << "region size = " << region.size() << " != " << dst->shape.size();
  for (int i = 0; i < region.size(); i++) {
    // bound check if region is static
    if (region[i]->min.as<IntImm>()) {
      int64_t min = Downcast<IntImm>(region[i]->min)->value;
      ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0";
    }
    if (region[i]->extent.as<IntImm>()) {
      int64_t extent = Downcast<IntImm>(region[i]->extent)->value;
      ICHECK_LE(extent, Downcast<IntImm>(dst->shape[i])->value)
          << "region[" << i << "] = " << extent << " > " << dst->shape[i];
    }
  }
}

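// Build the parallel loop nest that stores the fill value into every element
// of the destination region.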
For Fill::MakeSIMTLoop(arith::Analyzer *analyzer) const {
  int ndim = dst->shape.size();
  Array<IterVar> loop_vars;
  Array<PrimExpr> dst_indices;
  for (int i = 0; i < ndim; i++) {
    Var var = Var(std::string{char('i' + i)}, region[i]->extent->dtype);
    loop_vars.push_back({region[i], var, IterVarType::kDataPar});
    dst_indices.push_back(var);
  }
  Stmt body = BufferStore(dst, value, dst_indices);
  for (int i = ndim - 1; i >= 0; i--) {
    body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
               ForKind::kParallel, body);
  }
  return Downcast<For>(body);
}

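// Lower the fill according to the destination scope: fragment and shared
// buffers go through layout inference and thread partitioning before
// vectorization, plain "local" buffers are only vectorized, and any other
// scope is rejected.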
Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {

  if (dst.scope() == "local.fragment") {
    auto par_op = std::make_unique<ParallelOp>(MakeSIMTLoop(analyzer));
    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
                        InferLevel::kFree);
    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
                        InferLevel::kFree);
    auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                     par_op->GetLoopLayout());
    auto vectorized_thread_loop = VectorizeLoop(thread_loop);
    if (par_op->GetPredicate(T.thread_var).defined()) {
      return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                        vectorized_thread_loop);
    }
    return vectorized_thread_loop;
  } else if (dst.scope() == "local") {
    auto init_loop = MakeSIMTLoop(analyzer);
    auto vectorized_thread_loop = VectorizeLoop(init_loop);
    return vectorized_thread_loop;
  } else if (dst.scope() == "shared.dyn" || dst.scope() == "shared") {
    auto par_op = std::make_unique<ParallelOp>(MakeSIMTLoop(analyzer));
    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
                        InferLevel::kFree);
    auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                     par_op->GetLoopLayout());
    auto vectorized_thread_loop = VectorizeLoop(thread_loop);
    return vectorized_thread_loop;
  } else {
    LOG(FATAL) << "Unsupported scope " << dst.scope();
  }
}

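// Register the copy and fill operators. Both are marked opaque for TIR effect
// analysis. (At the frontend these correspond to calls like T.copy(src, dst)
// and T.fill(dst, value); the exact frontend spelling is an assumption, as it
// is not defined in this file.)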
TIR_REGISTER_TL_OP(Copy, copy)
    .set_num_inputs(3)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_REGISTER_TL_OP(Fill, fill)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm