/*!
 * \file tl/op/elem.cc
 *
 * Define element-wise operators.
 */

#include "elem.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../target/utils.h"
#include "../transform/common/loop_fusion_utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

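// Construct a Copy op from TL call arguments: args[0] and args[1] are
// region calls describing the source and destination; an optional args[2]
// carries the coalesced_width hint.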
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
  Array<Range> rgs[2];
  Buffer bf[2];
  for (int i = 0; i < 2; i++) {
    auto expr = args[i];
    auto call = expr.as<CallNode>();
    ICHECK(call);
    auto region = RegionOp(call->args, vmap);
    rgs[i] = region.GetRanges();
    bf[i] = region.GetBuffer();
  }
  std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]);
  std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]);
  if (args.size() >= 3) {
    coalesced_width = Downcast<IntImm>(args[2]);
  }
}

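// Create one data-parallel loop var per source dimension with extent > 1
// (named i, j, k, ... in order); unit-extent dimensions are skipped.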
Array<IterVar> Copy::MakeIterVars() const {
  Array<IterVar> loop_vars;
  size_t idx = 0;
  for (size_t i = 0; i < src_range.size(); i++) {
    if (is_one(src_range[i]->extent))
      continue;
    Var var = Var(std::string{char('i' + idx)});
    idx++;
    loop_vars.push_back(
        {Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
  }
  return loop_vars;
}

// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
Array<PrimExpr> Copy::MakeIndices(const Array<IterVar> &ivs,
                                  int src_dst) const {
  Array<PrimExpr> indices;
  Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
  size_t idx = 0;
  for (size_t i = 0; i < ranges.size(); i++) {
    if (is_one(ranges[i]->extent))
      indices.push_back(ranges[i]->min);
    else {
      indices.push_back(ranges[i]->min + ivs[idx]->var);
      idx++;
    }
  }
  ICHECK(idx == ivs.size())
      << "idx = " << idx << ", ivs.size() = " << ivs.size()
      << ", src name = " << src->name << ", dst name = " << dst->name;
  return indices;
}

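// Build the boundary predicate for one side (src_dst: 0 for src, 1 for dst).
// For each non-unit dimension, the checks (min + iv < extent) and
// (min + iv >= 0) are emitted unless the analyzer can already prove them;
// the result is their conjunction, or an undefined PrimExpr if no check is
// needed.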
PrimExpr Copy::MakePredicate(arith::Analyzer *analyzer,
                             const Array<IterVar> &ivs, Array<PrimExpr> extents,
                             int src_dst) const {
  Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
  Array<PrimExpr> cond_list;
  ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
  size_t idx = 0;
  for (size_t i = 0; i < ranges.size(); i++) {
    if (is_one(ranges[i]->extent))
      continue;
    PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i];
    if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
      cond_list.push_back(cond);
    }
    cond = ranges[i]->min + ivs[idx]->var >= 0;
    if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
      cond_list.push_back(cond);
    }
    idx++;
  }
  if (cond_list.empty())
    return {};
  else {
    PrimExpr cond = cond_list[0];
    for (size_t i = 1; i < cond_list.size(); i++)
      cond = And(cond, cond_list[i]);
    return cond;
  }
}

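// Build the generic element-wise copy as a nest of parallel loops:
// load from src (casting and predicating as needed), then store to dst.
// A scalar copy degenerates to a single-iteration serial loop.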
For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const {
  Array<IterVar> loop_vars = MakeIterVars();
  bool is_scalar = loop_vars.size() == 0;
  if (is_scalar) {
    return For(Var("i"), 0, 1, ForKind::kSerial,
               BufferStore(dst, BufferLoad(src, {0}), {0}));
  }

  for (const auto &iv : loop_vars)
    analyzer->Bind(iv->var, iv->dom);

  Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
  Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);

  PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
  PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);

  PrimExpr value = BufferLoad(src, src_indices);
  if (src->dtype != dst->dtype)
    value = Cast(dst->dtype, value);
  if (src_predicate.defined())
    value = if_then_else(src_predicate, value, make_zero(dst->dtype));

  Stmt body = BufferStore(dst, value, dst_indices);
  if (dst_predicate.defined())
    body = IfThenElse(dst_predicate, body);
  for (int i = loop_vars.size() - 1; i >= 0; i--) {
    Map<String, ObjectRef> annotations = {};
    if (coalesced_width.defined()) {
      annotations.Set("coalesced_width", coalesced_width);
    }
    body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
               ForKind::kParallel, body, NullOpt, annotations);
  }
  return Downcast<For>(body);
}

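// Lower the copy, preferring the specialized paths: ldmatrix/stmatrix
// first, then bulk copy; otherwise fall back to a fused SIMT loop that is
// partitioned across threads (on GPU) and vectorized.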
Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  Target target = T.target;
  bool is_cpu_target = target->GetTargetDeviceType() == kDLCPU;
  Stmt ldsm_stmt = LowerLDSMCopy(T, analyzer);
  if (ldsm_stmt.defined())
    return ldsm_stmt;

  Stmt bulk_copy_stmt = LowerBulkCopy(T, analyzer);
  if (bulk_copy_stmt.defined())
    return bulk_copy_stmt;
  auto simt_loop = MakeSIMTLoop(analyzer);
  auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));

  For vectorized_thread_loop;
  auto par_op = std::make_unique<ParallelOp>(fused_loop);

  if (is_cpu_target) {
    vectorized_thread_loop = VectorizeLoop(fused_loop);
  } else {
    par_op->InferLayout({T.target, T.block_size, T.layout_map, T.buffer_remap},
                        InferLevel::kFree);
    auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                     par_op->GetLoopLayout());
    vectorized_thread_loop = VectorizeLoop(thread_loop);
  }

  if (par_op->GetPredicate(T.thread_var).defined()) {
    return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                      vectorized_thread_loop);
  }

  return vectorized_thread_loop;
}

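// Try to lower the copy to ldmatrix/stmatrix. Returns an undefined Stmt
// whenever a precondition fails (buffer scopes, boundary predicates, layout
// shape, dtype width, or contiguity), so the caller can fall back.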
Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
  // Check buffer scope
  bool is_ldmatrix;
  if (TargetHasLdmatrix(T.target) && src.scope() == "shared.dyn" &&
      dst.scope() == "local.fragment") {
    is_ldmatrix = true;
  } else if (TargetHasStmatrix(T.target) && dst.scope() == "shared.dyn" &&
             src.scope() == "local.fragment") {
    is_ldmatrix = false;
  } else {
    return Stmt();
  }

  // Check no predicates
  Array<IterVar> loop_vars = MakeIterVars();
  if (loop_vars.size() < 2)
    return Stmt();
  for (const auto &iv : loop_vars)
    analyzer->Bind(iv->var, iv->dom);
  PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
  PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
  if (src_predicate.defined() || dst_predicate.defined())
    return Stmt();

  Buffer shared_tensor = is_ldmatrix ? src : dst;
  Buffer local_tensor = is_ldmatrix ? dst : src;

  Array<PrimExpr> local_indices = MakeIndices(loop_vars, is_ldmatrix ? 1 : 0);
  Fragment local_layout = Downcast<Fragment>(T.layout_map[local_tensor]);
  Array<PrimExpr> local_indices_transformed =
      local_layout->Forward(local_indices);
  local_tensor = T.buffer_remap[local_tensor];
  // currently only supports the 1-D case
  if (local_layout->OutputDim() != 1)
    return Stmt();

  Array<PrimExpr> shared_indices = MakeIndices(loop_vars, is_ldmatrix ? 0 : 1);
  Array<PrimExpr> shared_indices_transformed = shared_indices;
  Layout shared_layout;
  if (T.buffer_remap.count(shared_tensor)) {
    shared_layout = T.layout_map[shared_tensor];
    shared_tensor = T.buffer_remap[shared_tensor];
    shared_indices_transformed = shared_layout->Forward(shared_indices);
  }

  // Check that local_layout follows the 8x8 layout
  bool is_transposed;
  IterVar col_var = loop_vars[loop_vars.size() - 1];
  IterVar row_var = loop_vars[loop_vars.size() - 2];
  PrimExpr local_layout_thread_map =
      FloorMod(local_layout->ForwardThread(local_indices, NullOpt), 32);
  PrimExpr matrix_8x8_thread_map = makeGemmFragment8x8()->ForwardThread(
      {FloorMod(row_var, 8), FloorMod(col_var, 8)}, NullOpt);
  PrimExpr matrix_8x8_thread_map_trans =
      makeGemmFragment8x8Transposed()->ForwardThread(
          {FloorMod(row_var, 8), FloorMod(col_var, 8)}, NullOpt);
  PrimExpr local_indices_flattened =
      local_tensor.OffsetOf(local_indices_transformed).back();
  if (analyzer->CanProveEqual(matrix_8x8_thread_map, local_layout_thread_map) &&
      IndiceCanVectorize(local_indices_flattened, col_var->var,
                         col_var->dom->extent, 2, analyzer)) {
    is_transposed = false;
  } else if (analyzer->CanProveEqual(matrix_8x8_thread_map_trans,
                                     local_layout_thread_map) &&
             IndiceCanVectorize(local_indices_flattened, row_var->var,
                                row_var->dom->extent, 2, analyzer)) {
    is_transposed = true;
  } else {
    return Stmt();
  }
  // Check that the shared layout is 16-byte contiguous
  if (shared_tensor->dtype.bytes() != 2)
    return Stmt();
  PrimExpr flattened_indice =
      shared_tensor.OffsetOf(shared_indices_transformed).back();
  if (!IndiceCanVectorize(flattened_indice, loop_vars.back()->var,
                          loop_vars.back()->dom->extent, 8, analyzer))
    return Stmt();

  // Can only support the destination range being a full range
  for (size_t i = 0; i < dst_range.size(); i++) {
    if (!is_zero(dst_range[i]->min) ||
        !analyzer->CanProveEqual(dst_range[i]->extent, dst->shape[i]))
      return Stmt();
  }

  // Do the lowering here, try vectorized ldmatrix/stmatrix by 4/2/1
  PrimExpr extent = local_tensor->shape[0];
  int num = 1;
  if (analyzer->CanProveEqual(FloorMod(extent, 8), 0))
    num = 4;
  else if (analyzer->CanProveEqual(FloorMod(extent, 4), 0))
    num = 2;

  Array<PrimExpr> args;
  const Op &op = is_ldmatrix ? tl::LDMatrixOp() : tl::STMatrixOp();
  args.push_back(static_cast<int>(is_transposed));
  args.push_back(num);

  // Create the shared address with respect to the local address.
  // If not transposed:
  //   coords = Inverse(base + 2 * ((thread / 8) % num),
  //                    warp + (thread % 8) * 4)
  // If transposed:
  //   coords = Inverse(base + 2 * ((thread / 8) % num) + thread % 2,
  //                    warp + (thread % 8) / 2)
  Var local_iter("i");
  Layout inv = local_layout->Inverse();
  Array<PrimExpr> shared_coords;
  PrimExpr warp = FloorDiv(T.thread_var, 32) * 32;
  if (!is_transposed)
    shared_coords = inv->Forward(
        {local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num),
         warp + FloorMod(T.thread_var, 8) * 4});
  else
    shared_coords = inv->Forward(
        {local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num) +
             FloorMod(T.thread_var, 2),
         warp + FloorDiv(FloorMod(T.thread_var, 8), 2)});
  shared_coords.pop_back(); // drop the replication coordinate
  if (shared_layout.defined())
    shared_coords = shared_layout->Forward(shared_coords);
  PrimExpr shared_addr = shared_tensor.access_ptr(
      is_ldmatrix ? 1 : 2, DataType::Handle(), 1,
      shared_tensor.OffsetOf(shared_coords).back(), PrimExpr(2 * num));
  args.push_back(shared_addr);

  if (is_ldmatrix) {
    // Can only support the same dtype for ldmatrix
    if (local_tensor->dtype != shared_tensor->dtype)
      return Stmt();
    PrimExpr local_addr = local_tensor.access_ptr(
        2, DataType::Handle(), 1, local_iter * 2 * num, PrimExpr(2 * num));
    args.push_back(local_addr);
  } else {
    for (int i = 0; i < num; i++) {
      PrimExpr value0 =
          BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i});
      PrimExpr value1 =
          BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i + 1});
      if (local_tensor->dtype != shared_tensor->dtype) {
        value0 = Cast(shared_tensor->dtype, value0);
        value1 = Cast(shared_tensor->dtype, value1);
      }
      PrimExpr value_packed =
          Call(DataType::Int(32), PackB16Op(), {value0, value1});
      args.push_back(value_packed);
    }
  }

  auto body = Evaluate(Call(DataType::Handle(), op, args));
  For for_node =
      For(local_iter, 0, FloorDiv(extent, 2 * num), ForKind::kSerial, body);
  for_node = LoopPragmaUnroll(for_node);
  return for_node;
}

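// Infer layouts by delegating to a ParallelOp lazily built from the SIMT
// loop, after checking that fragment-to-fragment copies use matching
// layouts.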
LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  // Use parallel op to infer the layout
  if (par_op_ == nullptr) {
    arith::Analyzer analyzer;
    par_op_ = std::make_unique<ParallelOp>(MakeSIMTLoop(&analyzer));
  }
  if (T.layout_map.count(src) && T.layout_map.count(dst)) {
    // Only compare fragment layout
    if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
      const FragmentNode *src_layout = T.layout_map[src].as<Fragment>().get();
      const FragmentNode *dst_layout = T.layout_map[dst].as<Fragment>().get();
      if (src_layout && dst_layout) {
        ICHECK(src_layout->IsEqual(dst_layout, true))
            << "Got different layouts for " << src << " and " << dst
            << "\nLHS = " << src_layout->DebugOutput()
            << "\nRHS = " << dst_layout->DebugOutput()
            << "\nYou may need to use shared memory to transform the layout";
      }
    }
  }
  return par_op_->InferLayout(T, level);
}

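// Construct a Fill op: args[0] is the destination access pointer and
// args[1] the fill value, cast to the destination dtype when necessary.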
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
  dst = vmap[GetVarFromAccessPtr(args[0])];
  if (args[1]->dtype != dst->dtype) {
    value = Cast(dst->dtype, args[1]);
  } else {
    value = args[1];
  }
}

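// Build a full parallel loop nest over every destination dimension that
// stores the fill value.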
For Fill::MakeSIMTLoop(arith::Analyzer *analyzer) const {
  int ndim = dst->shape.size();
  Array<IterVar> loop_vars;
  Array<PrimExpr> dst_indices;
  for (int i = 0; i < ndim; i++) {
    Var var = Var(std::string{char('i' + i)});
    loop_vars.push_back({Range(0, dst->shape[i]), var, IterVarType::kDataPar});
    dst_indices.push_back(var);
  }
  Stmt body = BufferStore(dst, value, dst_indices);
  for (int i = ndim - 1; i >= 0; i--) {
    body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
               ForKind::kParallel, body);
  }
  return Downcast<For>(body);
}

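// Lower the fill: fragment destinations go through layout inference, loop
// partitioning and vectorization; plain local buffers are vectorized
// directly; other scopes are rejected.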
Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {

  if (dst.scope() == "local.fragment") {
    auto par_op = std::make_unique<ParallelOp>(MakeSIMTLoop(analyzer));
    par_op->InferLayout({T.target, T.block_size, T.layout_map},
                        InferLevel::kFree);
    par_op->InferLayout({T.target, T.block_size, T.layout_map},
                        InferLevel::kFree);
    auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                     par_op->GetLoopLayout());
    auto vectorized_thread_loop = VectorizeLoop(thread_loop);
    if (par_op->GetPredicate(T.thread_var).defined()) {
      return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                        vectorized_thread_loop);
    }
    return vectorized_thread_loop;
  } else if (dst.scope() == "local") {
    auto init_loop = MakeSIMTLoop(analyzer);
    auto vectorized_thread_loop = VectorizeLoop(init_loop);
    return vectorized_thread_loop;
  } else {
    LOG(FATAL) << "Unsupported scope " << dst.scope();
  }
}

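// Register the TL-level ops: copy takes three inputs (src, dst and the
// optional coalesced_width), fill takes two (dst, value); both are opaque
// to the call-effect analysis.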
TIR_REGISTER_TL_OP(Copy, copy)
    .set_num_inputs(3)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_REGISTER_TL_OP(Fill, fill)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm